Skip to content

Commit 240ae7a

Browse files
committed
Fix
1 parent 281c1ca commit 240ae7a

File tree

18 files changed

+148
-83
lines changed

18 files changed

+148
-83
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ class Analyzer(
856856
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
857857

858858
// To resolve duplicate expression IDs for Join and Intersect
859-
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
859+
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
860860
j.copy(right = dedupRight(left, right))
861861
case i @ Intersect(left, right) if !i.duplicateResolved =>
862862
i.copy(right = dedupRight(left, right))
@@ -2087,10 +2087,10 @@ class Analyzer(
20872087
*/
20882088
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
20892089
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
2090-
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
2090+
case j @ Join(left, right, UsingJoin(joinType, usingCols), _, _)
20912091
if left.resolved && right.resolved && j.duplicateResolved =>
20922092
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
2093-
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
2093+
case j @ Join(left, right, NaturalJoin(joinType), condition, _) if j.resolvedExceptNatural =>
20942094
// find common column names from both sides
20952095
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
20962096
commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ trait CheckAnalysis extends PredicateHelper {
147147
failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
148148
s"conditions: $condition")
149149

150-
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
150+
case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType =>
151151
failAnalysis(
152152
s"join condition '${condition.sql}' " +
153153
s"of type ${condition.dataType.simpleString} is not a boolean.")
@@ -583,7 +583,7 @@ trait CheckAnalysis extends PredicateHelper {
583583
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
584584

585585
// Join can host correlated expressions.
586-
case j @ Join(left, right, joinType, _) =>
586+
case j @ Join(left, right, joinType, _, _) =>
587587
joinType match {
588588
// Inner join, like Filter, can be anywhere.
589589
case _: InnerLike =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ object UnsupportedOperationChecker {
228228
throwError("dropDuplicates is not supported after aggregation on a " +
229229
"streaming DataFrame/Dataset")
230230

231-
case Join(left, right, joinType, condition) =>
231+
case Join(left, right, joinType, condition, _) =>
232232

233233
joinType match {
234234

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ trait PredicateHelper {
6363
}
6464
}
6565

66+
// An expression is null intolerant only if it and all its descendants are null intolerant.
67+
protected def isNullIntolerant(expr: Expression): Boolean = expr match {
68+
case e: NullIntolerant => e.children.forall(isNullIntolerant)
69+
case _ => false
70+
}
71+
6672
// Substitute any known alias from a map.
6773
protected def replaceAlias(
6874
condition: Expression,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
4242
} else {
4343
val result = plan transformDown {
4444
// Start reordering with a joinable item, which is an InnerLike join with conditions.
45-
case j @ Join(_, _, _: InnerLike, Some(cond)) =>
45+
case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
4646
reorder(j, j.output)
47-
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond)))
47+
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
4848
if projectList.forall(_.isInstanceOf[Attribute]) =>
4949
reorder(p, p.output)
5050
}
@@ -76,12 +76,12 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
7676
*/
7777
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
7878
plan match {
79-
case Join(left, right, _: InnerLike, Some(cond)) =>
79+
case Join(left, right, _: InnerLike, Some(cond), _) =>
8080
val (leftPlans, leftConditions) = extractInnerJoins(left)
8181
val (rightPlans, rightConditions) = extractInnerJoins(right)
8282
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
8383
leftConditions ++ rightConditions)
84-
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond)))
84+
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
8585
if projectList.forall(_.isInstanceOf[Attribute]) =>
8686
extractInnerJoins(j)
8787
case _ =>
@@ -90,11 +90,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
9090
}
9191

9292
private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
93-
case j @ Join(left, right, jt: InnerLike, Some(cond)) =>
93+
case j @ Join(left, right, jt: InnerLike, Some(cond), _) =>
9494
val replacedLeft = replaceWithOrderedJoin(left)
9595
val replacedRight = replaceWithOrderedJoin(right)
9696
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
97-
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) =>
97+
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) =>
9898
p.copy(child = replaceWithOrderedJoin(j))
9999
case _ =>
100100
plan

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
254254
// not allowed to use the same attributes. We use a blacklist to prevent us from creating a
255255
// situation in which this happens; the rule will only remove an alias if its child
256256
// attribute is not on the blacklist.
257-
case Join(left, right, joinType, condition) =>
257+
case Join(left, right, joinType, condition, notNullAttrs) =>
258258
val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet)
259259
val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet)
260260
val mapping = AttributeMap(
@@ -263,7 +263,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
263263
val newCondition = condition.map(_.transform {
264264
case a: Attribute => mapping.getOrElse(a, a)
265265
})
266-
Join(newLeft, newRight, joinType, newCondition)
266+
Join(newLeft, newRight, joinType, newCondition, notNullAttrs)
267267

268268
case _ =>
269269
// Remove redundant aliases in the subtree(s).
@@ -354,7 +354,7 @@ object LimitPushDown extends Rule[LogicalPlan] {
354354
// on both sides if it is applied multiple times. Therefore:
355355
// - If one side is already limited, stack another limit on top if the new limit is smaller.
356356
// The redundant limit will be collapsed by the CombineLimits rule.
357-
case LocalLimit(exp, join @ Join(left, right, joinType, _)) =>
357+
case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) =>
358358
val newJoin = joinType match {
359359
case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
360360
case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
@@ -468,7 +468,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
468468
p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
469469

470470
// Eliminate unneeded attributes from right side of a Left Existence Join.
471-
case j @ Join(_, right, LeftExistence(_), _) =>
471+
case j @ Join(_, right, LeftExistence(_), _, _) =>
472472
j.copy(right = prunedChild(right, j.references))
473473

474474
// all the columns will be used to compare, so we can't prune them
@@ -661,27 +661,38 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
661661
filter
662662
}
663663

664-
case join @ Join(left, right, joinType, conditionOpt) =>
664+
case join @ Join(left, right, joinType, conditionOpt, notNullAttrs) =>
665665
joinType match {
666666
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
667667
// inner join, it just drops the right side in the final output.
668668
case _: InnerLike | LeftSemi =>
669669
val allConstraints = getAllConstraints(left, right, conditionOpt)
670-
val newLeft = inferNewFilter(left, allConstraints)
671-
val newRight = inferNewFilter(right, allConstraints)
672-
join.copy(left = newLeft, right = newRight)
670+
val newLeftPredicates = inferNewPredicate(left, allConstraints)
671+
val newRightPredicates = inferNewPredicate(right, allConstraints)
672+
val newNotNullAttrs = getNotNullAttributes(
673+
newLeftPredicates ++ newRightPredicates, notNullAttrs)
674+
join.copy(
675+
left = addFilterIfNeeded(left, newLeftPredicates),
676+
right = addFilterIfNeeded(right, newRightPredicates),
677+
notNullAttributes = newNotNullAttrs)
673678

674679
// For right outer join, we can only infer additional filters for left side.
675680
case RightOuter =>
676681
val allConstraints = getAllConstraints(left, right, conditionOpt)
677-
val newLeft = inferNewFilter(left, allConstraints)
678-
join.copy(left = newLeft)
682+
val newLeftPredicates = inferNewPredicate(left, allConstraints)
683+
val newNotNullAttrs = getNotNullAttributes(newLeftPredicates, notNullAttrs)
684+
join.copy(
685+
left = addFilterIfNeeded(left, newLeftPredicates),
686+
notNullAttributes = newNotNullAttrs)
679687

680688
// For left outer/anti join, we can only infer additional filters for the right side.
681689
case LeftOuter | LeftAnti =>
682690
val allConstraints = getAllConstraints(left, right, conditionOpt)
683-
val newRight = inferNewFilter(right, allConstraints)
684-
join.copy(right = newRight)
691+
val newRightPredicates = inferNewPredicate(right, allConstraints)
692+
val newNotNullAttrs = getNotNullAttributes(newRightPredicates, notNullAttrs)
693+
join.copy(
694+
right = addFilterIfNeeded(right, newRightPredicates),
695+
notNullAttributes = newNotNullAttrs)
685696

686697
case _ => join
687698
}
@@ -696,16 +707,32 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
696707
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
697708
}
698709

699-
private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
700-
val newPredicates = constraints
710+
private def inferNewPredicate(
711+
plan: LogicalPlan, constraints: Set[Expression]): Set[Expression] = {
712+
constraints
701713
.union(constructIsNotNullConstraints(constraints, plan.output))
702714
.filter { c =>
703715
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
704716
} -- plan.constraints
705-
if (newPredicates.isEmpty) {
706-
plan
707-
} else {
717+
}
718+
719+
private def getNotNullAttributes(
720+
constraints: Set[Expression],
721+
curNotNullAttrs: Set[ExprId]): Set[ExprId] = {
722+
723+
// Keep only the IsNotNull predicates in `constraints` whose child is null intolerant
724+
val (notNullPreds, _) = constraints.partition {
725+
case IsNotNull(a) => isNullIntolerant(a)
726+
case _ => false
727+
}
728+
notNullPreds.flatMap(_.references.map(_.exprId)) ++ curNotNullAttrs
729+
}
730+
731+
private def addFilterIfNeeded(plan: LogicalPlan, newPredicates: Set[Expression]): LogicalPlan = {
732+
if (newPredicates.nonEmpty) {
708733
Filter(newPredicates.reduce(And), plan)
734+
} else {
735+
plan
709736
}
710737
}
711738
}
@@ -1048,7 +1075,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10481075

10491076
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
10501077
// push the where condition down into join filter
1051-
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
1078+
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, notNullAttrs)) =>
10521079
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
10531080
split(splitConjunctivePredicates(filterCondition), left, right)
10541081
joinType match {
@@ -1062,7 +1089,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10621089
commonFilterCondition.partition(canEvaluateWithinJoin)
10631090
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
10641091

1065-
val join = Join(newLeft, newRight, joinType, newJoinCond)
1092+
val join = Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs)
10661093
if (others.nonEmpty) {
10671094
Filter(others.reduceLeft(And), join)
10681095
} else {
@@ -1074,7 +1101,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10741101
val newRight = rightFilterConditions.
10751102
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
10761103
val newJoinCond = joinCondition
1077-
val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
1104+
val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, notNullAttrs)
10781105

10791106
(leftFilterConditions ++ commonFilterCondition).
10801107
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1084,7 +1111,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10841111
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
10851112
val newRight = right
10861113
val newJoinCond = joinCondition
1087-
val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
1114+
val newJoin = Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs)
10881115

10891116
(rightFilterConditions ++ commonFilterCondition).
10901117
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1094,7 +1121,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10941121
}
10951122

10961123
// push down the join filter into sub query scanning if applicable
1097-
case j @ Join(left, right, joinType, joinCondition) =>
1124+
case j @ Join(left, right, joinType, joinCondition, notNullAttrs) =>
10981125
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
10991126
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
11001127

@@ -1107,23 +1134,23 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
11071134
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
11081135
val newJoinCond = commonJoinCondition.reduceLeftOption(And)
11091136

1110-
Join(newLeft, newRight, joinType, newJoinCond)
1137+
Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs)
11111138
case RightOuter =>
11121139
// push down the left side only join filter for left side sub query
11131140
val newLeft = leftJoinConditions.
11141141
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
11151142
val newRight = right
11161143
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
11171144

1118-
Join(newLeft, newRight, RightOuter, newJoinCond)
1145+
Join(newLeft, newRight, RightOuter, newJoinCond, notNullAttrs)
11191146
case LeftOuter | LeftAnti | ExistenceJoin(_) =>
11201147
// push down the right side only join filter for right sub query
11211148
val newLeft = left
11221149
val newRight = rightJoinConditions.
11231150
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
11241151
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
11251152

1126-
Join(newLeft, newRight, joinType, newJoinCond)
1153+
Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs)
11271154
case FullOuter => j
11281155
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
11291156
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
@@ -1179,7 +1206,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
11791206
if (SQLConf.get.crossJoinEnabled) {
11801207
plan
11811208
} else plan transform {
1182-
case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _)
1209+
case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _)
11831210
if isCartesianProduct(j) =>
11841211
throw new AnalysisException(
11851212
s"""Detected cartesian product for ${j.joinType.sql} join between logical plans

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit
5656
// Joins on empty LocalRelations generated from streaming sources are not eliminated
5757
// as stateful streaming joins need to perform state management operations other than
5858
// just processing the input data.
59-
case p @ Join(_, _, joinType, _)
59+
case p @ Join(_, _, joinType, _, _)
6060
if !p.children.exists(_.isStreaming) =>
6161
val isLeftEmpty = isEmptyLocalRelation(p.left)
6262
val isRightEmpty = isEmptyLocalRelation(p.right)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
544544
// propagating the foldable expressions.
545545
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
546546
// of outer join.
547-
case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
547+
case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
548548
val newJoin = j.transformExpressions(replaceFoldable)
549549
val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
550550
case _: InnerLike | LeftExistence(_) => Nil

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
147147
}
148148

149149
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
150-
case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) =>
150+
case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _)) =>
151151
val newJoinType = buildNewJoinType(f, j)
152152
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
153153
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
5454
// the produced join then becomes unresolved and breaks structural integrity. We should
5555
// de-duplicate conflicting attributes. We don't use transformation here because we only
5656
// care about the most top join converted from correlated predicate subquery.
57-
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
57+
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond, _) =>
5858
val duplicates = right.outputSet.intersect(left.outputSet)
5959
if (duplicates.nonEmpty) {
6060
val aliasMap = AttributeMap(duplicates.map { dup =>

0 commit comments

Comments
 (0)