Skip to content

Commit 7fb2f9c

Browse files
committed
join propagation
1 parent 0c4c78b commit 7fb2f9c

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,27 +97,24 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
9797
}
9898

9999
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
100+
final override lazy val resolved: Boolean =
101+
childrenResolved &&
102+
left.output.length == right.output.length &&
103+
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
100104

101-
override def output: Seq[Attribute] =
102-
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
103-
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
104-
}
105+
override def extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = {
106+
child.constraints.filter(_.references.subsetOf(child.outputSet))
107+
}
105108

106109
protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left)
107110

108111
protected def rightConstraints: Set[Expression] = {
109112
require(left.output.size == right.output.size)
110113
val attributeRewrites = AttributeMap(right.output.zip(left.output))
111-
println(extractConstraintsFromChild(right), attributeRewrites)
112114
extractConstraintsFromChild(right).map(_ transform {
113115
case a: Attribute => attributeRewrites(a)
114116
})
115117
}
116-
117-
final override lazy val resolved: Boolean =
118-
childrenResolved &&
119-
left.output.length == right.output.length &&
120-
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
121118
}
122119

123120
private[sql] object SetOperation {
@@ -176,6 +173,10 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
176173
Statistics(sizeInBytes = sizeInBytes)
177174
}
178175

176+
override def extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = {
177+
child.constraints.filter(_.references.subsetOf(child.outputSet))
178+
}
179+
179180
def rewriteConstraints(
180181
planA: LogicalPlan,
181182
planB: LogicalPlan,
@@ -214,19 +215,43 @@ case class Join(
214215
}
215216
}
216217

218+
def extractNullabilityConstraintsFromJoinCondition(): Set[Expression] = {
219+
var constraints = Set.empty[Expression]
220+
if (condition.isDefined) {
221+
splitConjunctivePredicates(condition.get).foreach {
222+
case EqualTo(l, r) =>
223+
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
224+
case GreaterThan(l, r) =>
225+
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
226+
case GreaterThanOrEqual(l, r) =>
227+
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
228+
case LessThan(l, r) =>
229+
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
230+
case LessThanOrEqual(l, r) =>
231+
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
232+
}
233+
}
234+
// Currently we only propagate constraints if the condition consists of equality
235+
// and ranges. For all other cases, we return an empty set of constraints
236+
constraints
237+
}
238+
217239
override def constraints: Set[Expression] = {
218240
joinType match {
219241
case Inner =>
242+
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
243+
.union(extractNullabilityConstraintsFromJoinCondition())
220244
case LeftSemi =>
221245
extractConstraintsFromChild(left)
246+
.union(extractNullabilityConstraintsFromJoinCondition())
222247
case LeftOuter =>
223248
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
224249
case RightOuter =>
225250
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
226251
case FullOuter =>
227252
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
228253
case _ =>
229-
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
254+
Set.empty
230255
}
231256
}
232257

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
2929
private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
3030
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
3131

32-
test("propagating constraints in filter") {
32+
test("propagating constraints in filter/project") {
3333
val tr = LocalRelation('a.int, 'b.string, 'c.int)
3434
assert(tr.analyze.constraints.isEmpty)
3535
assert(tr.select('a.attr).analyze.constraints.isEmpty)
@@ -40,10 +40,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
4040
}
4141

4242
test("propagating constraints in union") {
43-
val tr1 = LocalRelation('a.int, 'b.string, 'c.int)
44-
val tr2 = LocalRelation('a.int, 'b.string, 'c.int)
45-
assert(tr1.analyze.constraints.isEmpty && tr2.analyze.constraints.isEmpty)
46-
assert(tr1.where('a.attr > 10).unionAll(tr2.where('a.attr > 10))
47-
.analyze.constraints == Set(resolveColumn(tr1, "a") > 10))
43+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
44+
val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
45+
val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
46+
assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10)
47+
.unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty)
48+
assert(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)
49+
.unionAll(tr3.where('g.attr > 10))).analyze.constraints == Set(resolveColumn(tr1, "a") > 10))
50+
}
51+
52+
test("propagating constraints in intersect") {
53+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
54+
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
55+
assert(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)).analyze.constraints ==
56+
Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
57+
}
58+
59+
test("propagating constraints in except") {
60+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
61+
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
62+
assert(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints ==
63+
Set(resolveColumn(tr1, "a") > 10))
4864
}
4965
}

0 commit comments

Comments
 (0)