Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix join filter inference from constraints
  • Loading branch information
cloud-fan committed Apr 17, 2018
commit 561db44b6a42a4a58996e58dfd9555bf45006e9f
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,8 @@ object CollapseWindow extends Rule[LogicalPlan] {
* constraints. These filters are currently inserted to the existing conditions in the Filter
* operators and on either side of Join operators.
*
* Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
* LeftSemi joins.
* Note: While this optimization is applicable to many join types, it primarily benefits
* Inner and LeftSemi joins.
*/
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {

Expand All @@ -661,21 +661,51 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}

case join @ Join(left, right, joinType, conditionOpt) =>
// Only consider constraints that can be pushed down completely to either the left or the
// right child
val constraints = join.constraints.filter { c =>
c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
}
// Remove those constraints that are already enforced by either the left or the right child
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
val newConditionOpt = conditionOpt match {
case Some(condition) =>
val newFilters = additionalConstraints -- splitConjunctivePredicates(condition)
if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None
case None =>
additionalConstraints.reduceOption(And)
joinType match {
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
// inner join, it just drops the right side in the final output.
case _: InnerLike | LeftSemi =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
val newRight = inferNewFilter(right, allConstraints)
join.copy(left = newLeft, right = newRight)

// For right outer join, we can only infer additional filters for left side.
case RightOuter =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
join.copy(left = newLeft)

// For left join, we can only infer additional filters for right side.
case LeftOuter | LeftAnti =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newRight = inferNewFilter(right, allConstraints)
join.copy(right = newRight)

case _ => join
}
if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join
}

/**
 * Collects every constraint known around a join: the constraints of both children, the
 * conjuncts of the join condition (when one is present), and whatever
 * `ConstraintsUtils.inferAdditionalConstraints` derives from that combined set.
 */
private def getAllConstraints(
    left: LogicalPlan,
    right: LogicalPlan,
    conditionOpt: Option[Expression]): Set[Expression] = {
  val childConstraints = left.constraints.union(right.constraints)
  // A missing join condition contributes no predicates.
  val conditionPredicates =
    conditionOpt.fold(Set.empty[Expression])(splitConjunctivePredicates(_).toSet)
  val known = childConstraints.union(conditionPredicates)
  known.union(ConstraintsUtils.inferAdditionalConstraints(known))
}

/**
 * Wraps `plan` in a [[Filter]] carrying the predicates from `constraints` that apply to it and
 * that it does not already enforce.
 *
 * @param plan the child plan that may receive an extra filter
 * @param constraints the candidate constraints to pick predicates from
 * @return `plan` itself when no new predicate applies, otherwise a `Filter` over `plan` whose
 *         condition is the conjunction of the new predicates
 */
private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may we return here the additional constraints instead of the new plan, so that in L669 and similar we can copy the plan only if it is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copying the plan is very cheap.

// Start from the given constraints plus the IsNotNull constraints derived from them for the
// plan's output, keep only deterministic predicates that reference at least one attribute and
// nothing outside the plan's output, then drop those the plan already guarantees.
val newPredicates = constraints
.union(ConstraintsUtils.constructIsNotNullConstraints(constraints, plan.output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
} -- plan.constraints
if (newPredicates.isEmpty) {
// Nothing new to enforce: return the plan untouched.
plan
} else {
// The emptiness check above makes `reduce` safe here.
Filter(newPredicates.reduce(And), plan)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ trait QueryPlanConstraints { self: LogicalPlan =>
if (conf.constraintPropagationEnabled) {
ExpressionSet(
validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints))
.union(ConstraintsUtils.inferAdditionalConstraints(validConstraints))
.union(ConstraintsUtils.constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
}
Expand All @@ -51,13 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* See [[Canonicalize]] for more details.
*/
protected def validConstraints: Set[Expression] = Set.empty
}

object ConstraintsUtils {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: to follow the pattern already used by PredicateHelper, what about having a ConstraintHelper trait instead of this object?


/**
 * Derives new constraints from attribute-to-attribute equalities in `constraints`.
 * For example, given (`a = 5`, `a = b`) the result contains `b = 5`.
 *
 * Constraints that are already part of the input are never returned.
 */
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
  val derived = constraints.foldLeft(Set.empty[Expression]) {
    case (acc, equality @ EqualTo(lhs: Attribute, rhs: Attribute)) =>
      // Rewrite all the other constraints in both directions of the equality.
      val others = constraints - equality
      acc ++ replaceConstraints(others, lhs, rhs) ++ replaceConstraints(others, rhs, lhs)
    case (acc, _) =>
      // Anything that is not `attribute = attribute` yields no inference.
      acc
  }
  derived -- constraints
}

/**
 * Returns a copy of `constraints` in which every sub-expression semantically equal to `source`
 * has been replaced by `destination`.
 */
private def replaceConstraints(
    constraints: Set[Expression],
    source: Expression,
    destination: Attribute): Set[Expression] = {
  def substitute(constraint: Expression): Expression = constraint.transform {
    case candidate: Expression if candidate.semanticEquals(source) => destination
  }
  constraints.map(substitute)
}

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
def constructIsNotNullConstraints(
constraints: Set[Expression],
output: Seq[Attribute]): Set[Expression] = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)

Expand Down Expand Up @@ -93,28 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan =>
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

/**
 * Infers an additional set of constraints from a given set of equality constraints.
 * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
 * additional constraint of the form `b = 5`.
 *
 * Only equalities between two attributes drive the inference, and constraints already present
 * in the input are excluded from the result.
 */
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
var inferredConstraints = Set.empty[Expression]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
// Substitute in every constraint except the equality itself, in both directions.
val candidateConstraints = constraints - eq
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case _ => // No inference
}
// Report only what is new relative to the input.
inferredConstraints -- constraints
}

/**
 * Rewrites each constraint, replacing every sub-expression that is semantically equal to
 * `source` with `destination`, and returns the rewritten set.
 */
private def replaceConstraints(
constraints: Set[Expression],
source: Expression,
destination: Attribute): Set[Expression] = constraints.map(_ transform {
// Match on semantic equality rather than plain equality.
case e: Expression if e.semanticEquals(source) => destination
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
BooleanSimplification) :: Nil
BooleanSimplification,
PruneFilters) :: Nil
}

// Base relation with three int columns (a, b, c) that the tests below build their plans on.
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

/**
 * Joins `x` and `y` on `x.a = y.a` using `joinType`, runs `Optimize` on the analyzed plan, and
 * compares the result with the same join built from `expectedLeft` and `expectedRight`.
 */
private def testConstraintsAfterJoin(
    x: LogicalPlan,
    y: LogicalPlan,
    expectedLeft: LogicalPlan,
    expectedRight: LogicalPlan,
    joinType: JoinType) = {
  val joinCondition = Some("x.a".attr === "y.a".attr)
  val actual = Optimize.execute(x.join(y, joinType, joinCondition).analyze)
  val expected = expectedLeft.join(expectedRight, joinType, joinCondition).analyze
  comparePlans(actual, expected)
}

test("filter: filter out constraints in condition") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val correctAnswer = testRelation
Expand Down Expand Up @@ -192,4 +206,61 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

comparePlans(Optimize.execute(original.analyze), correct.analyze)
}

test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // Both children are expected to gain an IsNotNull filter on the join key.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel.where(IsNotNull('a)),
    expectedRight = rightRel.where(IsNotNull('a)),
    joinType = LeftSemi)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // Only the right child is expected to gain an IsNotNull filter; the left stays as is.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel,
    expectedRight = rightRel.where(IsNotNull('a)),
    joinType = LeftAnti)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // Only the right child is expected to gain an IsNotNull filter; the left stays as is.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel,
    expectedRight = rightRel.where(IsNotNull('a)),
    joinType = LeftOuter)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // Only the left child is expected to gain an IsNotNull filter; the right stays as is.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel.where(IsNotNull('a)),
    expectedRight = rightRel,
    joinType = RightOuter)
}

test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)
  val joinCondition = Some("x.a".attr === "y.a".attr)
  // Input: a filter on x.a placed above a left outer join.
  val input = x.join(y, LeftOuter, joinCondition).where("x.a".attr === 2).analyze
  // Expected: the equality plus the not-null check under both join children.
  val expected = x.where(IsNotNull('a) && 'a === 2)
    .join(y.where(IsNotNull('a) && 'a === 2), LeftOuter, joinCondition)
    .analyze
  comparePlans(Optimize.execute(input), expected)
}

test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)
  val joinCondition = Some("x.a".attr === "y.a".attr)
  // Input: the right child of a right outer join already carries a filter on y.a.
  val input = x.join(y.where("y.a".attr > 5), RightOuter, joinCondition).analyze
  // Expected: the comparison plus the not-null check under both join children.
  val expected = x.where(IsNotNull('a) && 'a > 5)
    .join(y.where(IsNotNull('a) && 'a > 5), RightOuter, joinCondition)
    .analyze
  comparePlans(Optimize.execute(input), expected)
}

test("SPARK-21479: Outer join no filter push down to preserved side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // The left side of the left outer join is expected to stay unfiltered, while the right
  // side only gains the not-null check next to its existing filter.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel.where("a".attr === 1),
    expectedLeft = leftRel,
    expectedRight = rightRel.where(IsNotNull('a) && 'a === 1),
    joinType = LeftOuter)
}
}