Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improve test.
  • Loading branch information
viirya committed Oct 7, 2016
commit c0637b26808aed386c4d937ebca44958e9f89c09
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ object Canonicalize extends {
case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)

case o: Or => orderCommutative(o, { case Or(l, r) => Seq(l, r) }).reduce(Or)
case a: And => orderCommutative(a, { case And(l, r) => Seq(l, r)}).reduce(And)

case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ case class CNFNormalization(conf: CatalystConf) extends Rule[LogicalPlan] with P
}
finalPredicates = predicates.toSeq
}
val cnf = finalPredicates.map(toCNF(_, depth + 1))
val cnf = finalPredicates.map { p =>
if (p.semanticEquals(predicate)) {
p
} else {
toCNF(p, depth + 1)
}
}
if (depth == 0 && cnf.length > conf.maxPredicateNumberForCNFNormalization) {
return predicate
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,14 @@ class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int)

// Change the predicate orders in [[And]] and [[Or]] so we can compare them consistently.
private def normalizationPredicate(predicate: Expression): Expression = {
predicate transformUp {
case Or(a, b) =>
if (a.hashCode() > b.hashCode) {
Or(b, a)
} else {
Or(a, b)
}
case And(a, b) =>
if (a.hashCode() > b.hashCode) {
And(b, a)
} else {
And(a, b)
}
}
}

private def checkCondition(input: Expression, expected: Expression): Unit = {
val actual = Optimize.execute(testRelation.where(input).analyze)
val correctAnswer = Optimize.execute(testRelation.where(expected).analyze)

val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get

val normalizedResult = splitConjunctivePredicates(resultFilterExpression)
.map(normalizationPredicate).sortBy(_.toString)
val normalizedExpected = splitConjunctivePredicates(expectedFilterExpression)
.map(normalizationPredicate).sortBy(_.toString)

assert(normalizedResult == normalizedExpected)
assert(resultFilterExpression.semanticEquals(expectedFilterExpression))
}

private val a = Literal(1) < 'a
Expand Down Expand Up @@ -163,6 +140,16 @@ class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {
checkCondition(input, expected)
}

test("((a && b) || (c && d)) || e") {
val input = ((a && b) || (c && d)) || e
val expected = ((a || c) || e) && ((a || d) || e) && ((b || c) || e) && ((b || d) || e)
checkCondition(input, expected)
val analyzed = testRelation.where(input).analyze
val optimized = Optimize.execute(analyzed)
val resultFilterExpression = optimized.collectFirst { case f: Filter => f.condition }.get
println(s"resultFilterExpression: $resultFilterExpression")
}

test("CNF normalization exceeds max predicate numbers") {
val input = (1 to 100).map(i => Literal(i) < 'c).reduce(And) ||
(1 to 10).map(i => Literal(i) < 'a).reduce(And)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ class ExpressionSetSuite extends SparkFunSuite {
setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper)
setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper)

setTest(1, aUpper > bUpper && aUpper <= 10, aUpper <= 10 && aUpper > bUpper)
setTest(1,
aUpper > bUpper &&
bUpper > 100 &&
aUpper <= 10,
bUpper > 100 &&
aUpper <= 10 &&
aUpper > bUpper)

setTest(1, aUpper > bUpper || aUpper <= 10, aUpper <= 10 || aUpper > bUpper)
setTest(1,
aUpper > bUpper ||
bUpper > 100 ||
aUpper <= 10,
bUpper > 100 ||
aUpper <= 10 ||
aUpper > bUpper)

test("add to / remove from set") {
val initialSet = ExpressionSet(aUpper + 1 :: Nil)

Expand Down