Convert filter predicate to CNF in Optimizer.
viirya committed Oct 6, 2016
commit baac6327b5a9c1a234e34da538a72d8ef87a9e35
@@ -32,6 +32,9 @@ trait CatalystConf {
def optimizerInSetConversionThreshold: Int
def maxCaseBranchesForCodegen: Int

def maxDepthForCNFNormalization: Int
def maxPredicateNumberForCNFNormalization: Int

def runSQLonFile: Boolean

def warehousePath: String
@@ -60,6 +63,8 @@ case class SimpleCatalystConf(
optimizerMaxIterations: Int = 100,
optimizerInSetConversionThreshold: Int = 10,
maxCaseBranchesForCodegen: Int = 20,
maxDepthForCNFNormalization: Int = 10,
maxPredicateNumberForCNFNormalization: Int = 20,
runSQLonFile: Boolean = true,
crossJoinEnabled: Boolean = false,
warehousePath: String = "/user/hive/warehouse")
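For context, a minimal sketch of how a caller could tighten the two new knobs through the named defaults added above (the values are illustrative, and the leading caseSensitiveAnalysis parameter is assumed from the rest of the case class, not shown in this hunk):

// Illustrative only: override the CNF limits when building a test conf.
val conf = SimpleCatalystConf(
  caseSensitiveAnalysis = true,
  maxDepthForCNFNormalization = 5,
  maxPredicateNumberForCNFNormalization = 50)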
@@ -100,6 +100,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReorderAssociativeOperator,
LikeSimplification,
BooleanSimplification,
CNFNormalization(conf),
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.analysis._
@@ -132,6 +133,35 @@ case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
}
}

/**
 * Converts the predicates of [[Filter]] operators to conjunctive normal form (CNF).
 */
case class CNFNormalization(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
  private def toCNF(predicate: Expression, depth: Int = 0): Expression = {
    if (depth > conf.maxDepthForCNFNormalization) {
      // Stop rewriting past the configured depth; this also bounds the recursion.
      return predicate
    }
    // Distribute Or over And by cross-multiplying the conjuncts of each
    // disjunct with the clauses accumulated so far, e.g. (a && b) || c
    // becomes (a || c) && (b || c).
    val disjunctives = splitDisjunctivePredicates(predicate)
    var finalPredicates = splitConjunctivePredicates(disjunctives.head)
    disjunctives.tail.foreach { cond =>
      val predicates = new ArrayBuffer[Expression]()
      splitConjunctivePredicates(cond).foreach { p =>
        predicates ++= finalPredicates.map(Or(_, p))
      }
      finalPredicates = predicates.toSeq
    }
    val cnf = finalPredicates.map(toCNF(_, depth + 1))
    if (depth == 0 && cnf.length > conf.maxPredicateNumberForCNFNormalization) {
      // The distribution is exponential in the worst case; keep the original
      // predicate if the result has too many conjuncts.
      predicate
    } else {
      cnf.reduce(And)
    }
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f @ Filter(condition, _) => f.copy(condition = toCNF(condition))
  }
}
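To see what the distribution inside toCNF does, here is a self-contained sketch on a toy expression ADT (the Expr/Leaf names and the standalone toCNF are mine, not Catalyst's). It performs the same cross-multiplication of conjuncts and, like the rule above, relies on the depth cap for termination:

sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr

def conjuncts(e: Expr): Seq[Expr] = e match {
  case And(l, r) => conjuncts(l) ++ conjuncts(r)
  case other => Seq(other)
}

def disjuncts(e: Expr): Seq[Expr] = e match {
  case Or(l, r) => disjuncts(l) ++ disjuncts(r)
  case other => Seq(other)
}

def toCNF(e: Expr, depth: Int = 0, maxDepth: Int = 10): Expr = {
  if (depth > maxDepth) return e  // the depth cap also guarantees termination
  val ds = disjuncts(e)
  // Cross-multiply conjuncts: (a && b) || c => (a || c) && (b || c).
  val clauses = ds.tail.foldLeft(conjuncts(ds.head)) { (acc, d) =>
    for (p <- conjuncts(d); q <- acc) yield Or(q, p)
  }
  clauses.map(toCNF(_, depth + 1, maxDepth)).reduce(And)
}

println(toCNF(Or(And(Leaf("a"), Leaf("b")), Leaf("c"))))
// And(Or(Leaf(a),Leaf(c)),Or(Leaf(b),Leaf(c)))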

/**
* Simplifies boolean expressions:
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.SimpleCatalystConf

class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation,
ConstantFolding,
BooleanSimplification,
CNFNormalization(SimpleCatalystConf(true)),
PruneFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int)

// Reorder the operands of [[And]] and [[Or]] deterministically so that
// logically equivalent predicates can be compared consistently.
private def normalizePredicate(predicate: Expression): Expression = {
  predicate transformUp {
    case Or(a, b) if a.hashCode > b.hashCode => Or(b, a)
    case And(a, b) if a.hashCode > b.hashCode => And(b, a)
  }
}

private def checkCondition(input: Expression, expected: Expression): Unit = {
val actual = Optimize.execute(testRelation.where(input).analyze)
val correctAnswer = Optimize.execute(testRelation.where(expected).analyze)

val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get

val normalizedResult = splitConjunctivePredicates(resultFilterExpression)
  .map(normalizePredicate).sortBy(_.toString)
val normalizedExpected = splitConjunctivePredicates(expectedFilterExpression)
  .map(normalizePredicate).sortBy(_.toString)

assert(normalizedResult == normalizedExpected)
}

private val a = Literal(1) < 'a
private val b = Literal(1) < 'b
private val c = Literal(1) < 'c
private val d = Literal(1) < 'd
private val e = Literal(1) < 'e
private val f = !a  // f is the negation of a, so (a || f) is always true

test("a || b => a || b") {
checkCondition(a || b, a || b)
}

test("a && b && c => a && b && c") {
checkCondition(a && b && c, a && b && c)
}

test("a && !(b || c) => a && !b && !c") {
checkCondition(a && !(b || c), a && !b && !c)
}

test("a && b || c => (a || c) && (b || c)") {
checkCondition(a && b || c, (a || c) && (b || c))
}

test("a && b || f => (a || f) && (b || f)") {
checkCondition(a && b || f, b || f)
}

test("(a && b) || (c && d) => (c || a) && (c || b) && ((d || a) && (d || b))") {
checkCondition((a && b) || (c && d), (a || c) && (b || c) && (a || d) && (b || d))
}

test("(a && b) || !(c && d) => (a || !c || !d) && (b || !c || !d)") {
checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
}

test("a || b || c && d => (a || b || c) && (a || b || d)") {
checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
}

test("a || (b && c || d) => (a || b || d) && (a || c || d)") {
checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
}

test("a || !(b && c || d) => (a || !b || !c) && (a || !d)") {
checkCondition(a || !(b && c || d), (a || !b || !c) && (a || !d))
}

test("a && (b && c || d && e) => a && (b || d) && (c || d) && (b || e) && (c || e)") {
val input = a && (b && c || d && e)
val expected = a && (b || d) && (c || d) && (b || e) && (c || e)
checkCondition(input, expected)
}

test("a && !(b && c || d && e) => a && (!b || !c) && (!d || !e)") {
checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
}

test(
"a || (b && c || d && e) => (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)") {
val input = a || (b && c || d && e)
val expected = (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)
checkCondition(input, expected)
}

test("a || !(b && c || d && e) => (a || !b || !c) && (a || !d || !e)") {
checkCondition(a || !(b && c || d && e), (a || !b || !c) && (a || !d || !e))
}

test("a && b && c || !(d && e) => (a || !d || !e) && (b || !d || !e) && (c || !d || !e)") {
val input = a && b && c || !(d && e)
val expected = (a || !d || !e) && (b || !d || !e) && (c || !d || !e)
checkCondition(input, expected)
}

test(
"a && b && c || d && e && f => " +
"(a || d) && (a || e) && (a || f) && (b || d) && " +
"(b || e) && (b || f) && (c || d) && (c || e) && (c || f)") {
val input = (a && b && c) || (d && e && f)
val expected = (a || d) && (a || e) && (a || f) &&
(b || d) && (b || e) && (b || f) &&
(c || d) && (c || e) && (c || f)
checkCondition(input, expected)
}

test("CNF normalization exceeds max predicate numbers") {
val input = (1 to 100).map(i => Literal(i) < 'c).reduce(And) ||
(1 to 10).map(i => Literal(i) < 'a).reduce(And)
val analyzed = testRelation.where(input).analyze
val optimized = Optimize.execute(analyzed)
val resultFilterExpression = optimized.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = analyzed.collectFirst { case f: Filter => f.condition }.get
assert(resultFilterExpression.semanticEquals(expectedFilterExpression))
}
}
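The final test above exercises the size guard. Distributing a disjunction of two conjunctions produces one clause per pair of conjuncts, so the clause count is the product of the conjunct counts; a quick check (plain arithmetic, not part of the patch):

// 100 conjuncts OR-ed with 10 conjuncts distribute into 100 * 10 = 1000
// clauses, far above the default cap of 20, so the rule returns the
// original predicate and the filter condition is semantically unchanged.
val clauseCount = 100 * 10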
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.types.IntegerType

class FilterPushdownSuite extends PlanTest {
@@ -37,6 +38,7 @@ class FilterPushdownSuite extends PlanTest {
CombineFilters,
PushDownPredicate,
BooleanSimplification,
CNFNormalization(SimpleCatalystConf(true)),
PushPredicateThroughJoin,
CollapseProject) :: Nil
}
@@ -1018,4 +1020,46 @@ class FilterPushdownSuite extends PlanTest {

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

test("push down filters that are not be able to pushed down after simplification") {
// The following predicate ('a === 2 || 'a === 3) && ('c > 10 || 'a === 2)
// will be simplified as ('a == 2) || ('c > 10 && 'a == 3).
// In its original form, ('a === 2 || 'a === 3) can be pushed down.
// But the simplified one can't.
val originalQuery = testRelation
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10) // this predicate can't be pushed down.
.where(('a === 2 || 'a === 3) && ('c > 10 || 'a === 2))

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 2 || 'a === 3)
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10).analyze

comparePlans(optimized, correctAnswer)
}

test("disjunctive predicates which are able to pushdown should be pushed down after converted") {
// (('a === 2) || ('c > 10 || 'a === 3)) can't be pushdown due to the disjunctive form.
// However, its conjunctive normal form can be pushdown.
val originalQuery = testRelation
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10)
.where(('a === 2) || ('c > 10 && 'a === 3))

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 2 || 'a === 3)
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10).analyze

comparePlans(optimized, correctAnswer)
}
}
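Both new tests hinge on the fact that predicate pushdown operates per conjunct, which is exactly what CNF exposes. A minimal sketch of the criterion for moving a clause below an Aggregate (a hypothetical helper; the real logic lives in PushDownPredicate):

import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}

// Hypothetical helper: a CNF clause may move below an Aggregate only if
// every attribute it references is a grouping column.
def canPushBelowAggregate(clause: Expression, groupingAttrs: AttributeSet): Boolean =
  clause.references.subsetOf(groupingAttrs)

// For ('a === 2 || 'c > 10) && ('a === 2 || 'a === 3): the first conjunct
// references 'c (an aggregate output) and stays above; the second references
// only 'a, a grouping column, and is pushed down.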
@@ -454,6 +454,18 @@ object SQLConf {
.intConf
.createWithDefault(20)

val MAX_DEPTH_CNF_PREDICATE = SQLConfigBuilder("spark.sql.expression.cnf.maxDepth")
.internal()
.doc("The maximum depth of converting recursively filter predicates to CNF normalization.")
.intConf
.createWithDefault(10)

val MAX_PREDICATE_NUMBER_CNF_PREDICATE = SQLConfigBuilder("spark.sql.expression.cnf.maxNumber")
.internal()
.doc("The maximum number of predicates in the CNF normalization of filter predicates")
.intConf
.createWithDefault(20)

val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -685,6 +697,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

def maxDepthForCNFNormalization: Int = getConf(MAX_DEPTH_CNF_PREDICATE)

def maxPredicateNumberForCNFNormalization: Int = getConf(MAX_PREDICATE_NUMBER_CNF_PREDICATE)

def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
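Since both limits are exposed through SQLConf, they can be adjusted at runtime like any other conf key (illustrative values; both flags are internal, so the defaults should normally suffice):

// Assumes a SparkSession in scope as `spark`.
spark.conf.set("spark.sql.expression.cnf.maxDepth", 20)
spark.conf.set("spark.sql.expression.cnf.maxNumber", 100)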