@@ -97,27 +97,24 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
9797}
9898
9999abstract class SetOperation (left : LogicalPlan , right : LogicalPlan ) extends BinaryNode {
100+ final override lazy val resolved : Boolean =
101+ childrenResolved &&
102+ left.output.length == right.output.length &&
103+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
100104
101- override def output : Seq [Attribute ] =
102- left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
103- leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
104- }
105+ override def extractConstraintsFromChild (child : QueryPlan [LogicalPlan ]): Set [Expression ] = {
106+ child.constraints.filter(_.references.subsetOf(child.outputSet))
107+ }
105108
106109 protected def leftConstraints : Set [Expression ] = extractConstraintsFromChild(left)
107110
108111 protected def rightConstraints : Set [Expression ] = {
109112 require(left.output.size == right.output.size)
110113 val attributeRewrites = AttributeMap (right.output.zip(left.output))
111- println(extractConstraintsFromChild(right), attributeRewrites)
112114 extractConstraintsFromChild(right).map(_ transform {
113115 case a : Attribute => attributeRewrites(a)
114116 })
115117 }
116-
117- final override lazy val resolved : Boolean =
118- childrenResolved &&
119- left.output.length == right.output.length &&
120- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
121118}
122119
123120private [sql] object SetOperation {
@@ -176,6 +173,10 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
176173 Statistics (sizeInBytes = sizeInBytes)
177174 }
178175
176+ override def extractConstraintsFromChild (child : QueryPlan [LogicalPlan ]): Set [Expression ] = {
177+ child.constraints.filter(_.references.subsetOf(child.outputSet))
178+ }
179+
179180 def rewriteConstraints (
180181 planA : LogicalPlan ,
181182 planB : LogicalPlan ,
@@ -214,19 +215,43 @@ case class Join(
214215 }
215216 }
216217
218+ def extractNullabilityConstraintsFromJoinCondition (): Set [Expression ] = {
219+ var constraints = Set .empty[Expression ]
220+ if (condition.isDefined) {
221+ splitConjunctivePredicates(condition.get).foreach {
222+ case EqualTo (l, r) =>
223+ constraints = constraints.union(Set (IsNotNull (l), IsNotNull (r)))
224+ case GreaterThan (l, r) =>
225+ constraints = constraints.union(Set (IsNotNull (l), IsNotNull (r)))
226+ case GreaterThanOrEqual (l, r) =>
227+ constraints = constraints.union(Set (IsNotNull (l), IsNotNull (r)))
228+ case LessThan (l, r) =>
229+ constraints = constraints.union(Set (IsNotNull (l), IsNotNull (r)))
230+ case LessThanOrEqual (l, r) =>
231+ constraints = constraints.union(Set (IsNotNull (l), IsNotNull (r)))
232+ }
233+ }
234+ // Currently we only propagate constraints if the condition consists of equality
235+ // and ranges. For all other cases, we return an empty set of constraints
236+ constraints
237+ }
238+
217239 override def constraints : Set [Expression ] = {
218240 joinType match {
219241 case Inner =>
242+ extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
243+ .union(extractNullabilityConstraintsFromJoinCondition())
220244 case LeftSemi =>
221245 extractConstraintsFromChild(left)
246+ .union(extractNullabilityConstraintsFromJoinCondition())
222247 case LeftOuter =>
223248 extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
224249 case RightOuter =>
225250 extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
226251 case FullOuter =>
227252 extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
228253 case _ =>
229- extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
254+ Set .empty
230255 }
231256 }
232257
0 commit comments