Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class ConstraintPropagationSuite extends SparkFunSuite {
private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
plan.resolveQuoted(columnName, caseInsensitiveResolution).get

private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _))
private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = {
val missing = expected -- found
val extra = found -- expected
if (missing.nonEmpty || extra.nonEmpty) {
fail(
s"""
Expand All @@ -58,18 +58,18 @@ class ConstraintPropagationSuite extends SparkFunSuite {
verifyConstraints(tr
.where('a.attr > 10)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10,
IsNotNull(resolveColumn(tr, "a"))))
ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
IsNotNull(resolveColumn(tr, "a")))))

verifyConstraints(tr
.where('a.attr > 10)
.select('c.attr, 'a.attr)
.where('c.attr < 100)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10,
ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
resolveColumn(tr, "c") < 100,
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c"))))
IsNotNull(resolveColumn(tr, "c")))))
}

test("propagating constraints in aggregate") {
Expand All @@ -81,10 +81,10 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze

verifyConstraints(aliasedRelation.analyze.constraints,
Set(resolveColumn(aliasedRelation.analyze, "c1") > 10,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "c1") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")),
resolveColumn(aliasedRelation.analyze, "a") < 5,
IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))
IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
}

test("propagating constraints in aliases") {
Expand All @@ -95,11 +95,11 @@ class ConstraintPropagationSuite extends SparkFunSuite {
val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))

verifyConstraints(aliasedRelation.analyze.constraints,
Set(resolveColumn(aliasedRelation.analyze, "x") > 10,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
}

test("propagating constraints in union") {
Expand All @@ -118,8 +118,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.unionAll(tr2.where('d.attr > 10)
.unionAll(tr3.where('g.attr > 10)))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
}

test("propagating constraints in intersect") {
Expand All @@ -130,10 +130,10 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.intersect(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
resolveColumn(tr1, "b") < 100,
IsNotNull(resolveColumn(tr1, "a")),
IsNotNull(resolveColumn(tr1, "b"))))
IsNotNull(resolveColumn(tr1, "b")))))
}

test("propagating constraints in except") {
Expand All @@ -143,8 +143,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.except(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
}

test("propagating constraints in inner join") {
Expand All @@ -154,13 +154,13 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
tr2.resolveQuoted("a", caseInsensitiveResolution).get,
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
}

test("propagating constraints in left-semi join") {
Expand All @@ -170,8 +170,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
}

test("propagating constraints in left-outer join") {
Expand All @@ -181,8 +181,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
}

test("propagating constraints in right-outer join") {
Expand All @@ -192,8 +192,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
}

test("propagating constraints in full-outer join") {
Expand Down