Skip to content
Closed
Prev Previous commit
Next Next commit
support all joins
  • Loading branch information
sameeragarwal committed Jan 27, 2016
commit f15ef96657603b79a815853fba991835fe3ca50f
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,13 @@ case class Join(
}
}

def extractNullabilityConstraintsFromJoinCondition(): Set[Expression] = {
override def constraints: Set[Expression] = {
var constraints = Set.empty[Expression]
if (condition.isDefined) {
splitConjunctivePredicates(condition.get).foreach {

// Currently we only propagate constraints if the condition consists of equality
// and ranges. For all other cases, we return an empty set of constraints
def extractIsNotNullConstraints(condition: Expression): Set[Expression] = {
splitConjunctivePredicates(condition).foreach {
case EqualTo(l, r) =>
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
case GreaterThan(l, r) =>
Expand All @@ -230,29 +233,32 @@ case class Join(
case LessThanOrEqual(l, r) =>
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
}
constraints
}
// Currently we only propagate constraints if the condition consists of equality
// and ranges. For all other cases, we return an empty set of constraints
constraints
}

override def constraints: Set[Expression] = {
joinType match {
case Inner =>
def extractIsNullConstraints(plan: LogicalPlan): Set[Expression] = {
constraints = constraints.union(plan.output.map(IsNull).toSet)
constraints
}

constraints = joinType match {
case Inner if condition.isDefined =>
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
.union(extractNullabilityConstraintsFromJoinCondition())
case LeftSemi =>
extractConstraintsFromChild(left)
.union(extractNullabilityConstraintsFromJoinCondition())
case LeftOuter =>
.union(extractIsNotNullConstraints(condition.get))
case LeftSemi if condition.isDefined =>
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
.union(extractIsNotNullConstraints(condition.get))
case LeftOuter =>
extractConstraintsFromChild(left).union(extractIsNullConstraints(right))
case RightOuter =>
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
extractConstraintsFromChild(right).union(extractIsNullConstraints(left))
case FullOuter =>
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
extractIsNullConstraints(left).union(extractIsNullConstraints(right))
case _ =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What types of joins are we not handling? It might be better to get an exception if we add a new type and it's not handled, but I'm not sure.

Set.empty
}

constraints.filter(_.references.subsetOf(outputSet))
}

/** True when the left and right children have no output attributes in common. */
def selfJoinResolved: Boolean = {
  val sharedOutput = left.outputSet.intersect(right.outputSet)
  sharedOutput.isEmpty
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a merging mistake, as it's duplicated with the method below.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, fixed!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,19 @@ class ConstraintPropagationSuite extends SparkFunSuite {
/**
 * Analyzes `tr` and looks up `columnName` in the analyzed plan using case-insensitive
 * resolution. The lookup result is unwrapped with `.get`, so this presumably throws when the
 * column cannot be resolved.
 */
private def resolveColumn(tr: LocalRelation, columnName: String): Expression = {
  val analyzedPlan = tr.analyze
  val lookup = analyzedPlan.resolveQuoted(columnName, caseInsensitiveResolution)
  lookup.get
}

/**
 * Checks that `a` and `b` hold the same constraints, matching elements with `semanticEquals`
 * rather than `==`. The callers visible in this suite pass a plan's constraints as `a` and the
 * expected set as `b`, which is what the failure messages below assume.
 */
private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = {
  // `exists` replaces `map(...).reduce(_ || _)`: `reduce` throws on an empty set, which hid the
  // real mismatch behind an UnsupportedOperationException whenever only one side was empty.
  val unexpected = a.filterNot(found => b.exists(_.semanticEquals(found)))
  val missing = b.filterNot(wanted => a.exists(_.semanticEquals(wanted)))
  // Report the two directions separately so a failure names the offending constraints.
  assert(unexpected.isEmpty,
    s"Found constraints that were not expected: ${unexpected.mkString(", ")}")
  assert(missing.isEmpty,
    s"Expected constraints that were not found: ${missing.mkString(", ")}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make this function manually call fail with the condition that we can't find, and also differentiate between missing and found but not expected.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, good idea.

}

// Checks which constraints `analyze.constraints` reports for plans built from filters and
// projections over a single three-column relation.
test("propagating constraints in filter/project") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
// A bare relation, and a plain projection of it, are expected to report no constraints.
assert(tr.analyze.constraints.isEmpty)
assert(tr.select('a.attr).analyze.constraints.isEmpty)
// A filter is expected to contribute its predicate as a constraint.
assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn(tr, "a") > 10))
// Once `a` is projected away, the `a > 10` constraint is expected to be dropped.
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
// With both referenced columns kept, both filter predicates are expected.
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
.analyze.constraints == Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
// NOTE(review): the two calls below check the same plans and expected sets as the `==`
// assertions above, but through `verifyConstraints`. This looks like both sides of a diff;
// confirm which form is meant to remain.
verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10))
verifyConstraints(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
.analyze.constraints, Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
}

test("propagating constraints in union") {
Expand All @@ -45,21 +50,76 @@ class ConstraintPropagationSuite extends SparkFunSuite {
val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10)
.unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty)
assert(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)
.unionAll(tr3.where('g.attr > 10))).analyze.constraints == Set(resolveColumn(tr1, "a") > 10))
verifyConstraints(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)
.unionAll(tr3.where('g.attr > 10))).analyze.constraints, Set(resolveColumn(tr1, "a") > 10))
}

// Checks the constraints reported for an intersect of two filtered relations that declare the
// same column names.
test("propagating constraints in intersect") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
// Both children's filter predicates are expected, each expressed on columns resolved from tr1.
assert(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)).analyze.constraints ==
Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
// NOTE(review): this repeats the check above through `verifyConstraints`. It looks like both
// sides of a diff; confirm which form is meant to remain.
verifyConstraints(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100))
.analyze.constraints, Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
}

test("propagating constraints in except") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
assert(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints ==
verifyConstraints(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints,
Set(resolveColumn(tr1, "a") > 10))
}

test("propagating constraints in inner join") {
  // Two aliased relations that both expose a column named `a`, joined on tr1.a = tr2.a.
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
  val leftCol = (name: String) => tr1.resolveQuoted(name, caseInsensitiveResolution).get
  val rightCol = (name: String) => tr2.resolveQuoted(name, caseInsensitiveResolution).get

  val joined = tr1.where('a.attr > 10)
    .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
  val actual = joined.analyze.constraints

  // Expected: both children's filter predicates plus IsNotNull on the two join keys.
  val expected: Set[Expression] = Set(
    leftCol("a") > 10,
    rightCol("d") < 100,
    IsNotNull(rightCol("a")),
    IsNotNull(leftCol("a")))

  verifyConstraints(actual, expected)
}

test("propagating constraints in left-semi join") {
  // Two aliased relations that both expose a column named `a`, semi-joined on tr1.a = tr2.a.
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
  val leftKey = tr1.resolveQuoted("a", caseInsensitiveResolution)

  val semiJoined = tr1.where('a.attr > 10)
    .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
  val actual = semiJoined.analyze.constraints

  // Expected: only constraints over the left child's columns, i.e. its filter predicate and
  // IsNotNull on its join key.
  val expected: Set[Expression] = Set(leftKey.get > 10, IsNotNull(leftKey.get))

  verifyConstraints(actual, expected)
}

test("propagating constraints in left-outer join") {
  // Two aliased relations that both expose a column named `a`, left-outer-joined on
  // tr1.a = tr2.a.
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
  val rightCol = (name: String) => tr2.resolveQuoted(name, caseInsensitiveResolution).get

  val joined = tr1.where('a.attr > 10)
    .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
  val actual = joined.analyze.constraints

  // Expected: the left child's filter predicate plus IsNull on every right-side column.
  // NOTE(review): IsNull on the right columns would only describe unmatched rows; confirm
  // this is the constraint the join is meant to report.
  val leftFilter: Set[Expression] = Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10)
  val expected: Set[Expression] =
    leftFilter ++ Seq("a", "d", "e").map(name => IsNull(rightCol(name)))

  verifyConstraints(actual, expected)
}

test("propagating constraints in right-outer join") {
  // Two aliased relations that both expose a column named `a`, right-outer-joined on
  // tr1.a = tr2.a.
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
  val leftCol = (name: String) => tr1.resolveQuoted(name, caseInsensitiveResolution).get

  val joined = tr1.where('a.attr > 10)
    .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
  val actual = joined.analyze.constraints

  // Expected: the right child's filter predicate plus IsNull on every left-side column.
  // NOTE(review): IsNull on the left columns would only describe unmatched rows; confirm
  // this is the constraint the join is meant to report.
  val rightFilter: Set[Expression] =
    Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100)
  val expected: Set[Expression] =
    rightFilter ++ Seq("a", "b", "c").map(name => IsNull(leftCol(name)))

  verifyConstraints(actual, expected)
}

test("propagating constraints in full-outer join") {
  // Two aliased relations that both expose a column named `a`, full-outer-joined on
  // tr1.a = tr2.a.
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)

  val joined = tr1.where('a.attr > 10)
    .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
  val actual = joined.analyze.constraints

  // Expected: IsNull on every column of both children and neither child's filter predicate.
  // NOTE(review): IsNull on both sides at once would not describe matched rows; confirm this
  // is the constraint the join is meant to report.
  val leftNulls: Set[Expression] = Set("a", "b", "c")
    .map(name => IsNull(tr1.resolveQuoted(name, caseInsensitiveResolution).get))
  val rightNulls: Set[Expression] = Set("a", "d", "e")
    .map(name => IsNull(tr2.resolveQuoted(name, caseInsensitiveResolution).get))

  verifyConstraints(actual, leftNulls ++ rightNulls)
}
}