fix tests

sameeragarwal committed Jan 20, 2016
commit 04ff99ab96957f57c74ae24835b4bfcdd27e06b8
@@ -30,16 +30,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]]
/**
* Extracts the output property from a given child.
*/
- def extractConstraintsFromChild(child: QueryPlan[PlanType]): Seq[Expression] = {
+ def extractConstraintsFromChild(child: QueryPlan[PlanType]): Set[Expression] = {
Contributor:
protected?

Also, I'm not sure I get the Scala doc. Maybe getRelevantConstraints is a better name? It takes the constraints and removes those that no longer apply because we removed columns, right?

child.constraints.filter(_.references.subsetOf(outputSet))
}
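
To make the comment above concrete, here is a minimal, self-contained sketch of the filtering this method performs. It uses simplified stand-in types (Attribute, Constraint) instead of Catalyst's Expression/AttributeSet, and borrows the getRelevantConstraints name from the suggestion above purely for illustration.

// Sketch only: stand-in types, not Catalyst's real classes.
object RelevantConstraintsSketch extends App {
  case class Attribute(name: String)
  case class Constraint(predicate: String, references: Set[Attribute])

  // Keep only the child constraints whose referenced attributes all survive in
  // this operator's output: a projection that drops column `a` also drops `a > 10`.
  def getRelevantConstraints(
      childConstraints: Set[Constraint],
      outputSet: Set[Attribute]): Set[Constraint] =
    childConstraints.filter(_.references.subsetOf(outputSet))

  val a = Attribute("a")
  val c = Attribute("c")
  val childConstraints = Set(
    Constraint("a > 10", Set(a)),
    Constraint("c < 100", Set(c)))

  // Only `c` is still in the output, so only `c < 100` remains relevant.
  println(getRelevantConstraints(childConstraints, Set(c)))
}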

/**
* A sequence of expressions that describes the data properties of the output rows of this
* operator. For example, if the output of this operator is column `a`, an example `constraints`
- * can be `Seq(a > 10, a < 20)`.
+ * can be `Set(a > 10, a < 20)`.
*/
- def constraints: Seq[Expression] = Nil
+ def constraints: Set[Expression] = Set.empty
Contributor:
This is probably going to be nontrivial to calculate for a large tree. We might consider having an internal method, `private def validConstraints` or something, that we expand/canonicalize into a `lazy val constraints`.
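
As a hedged sketch of the pattern this comment suggests (not the actual Catalyst code): the raw constraints come from an overridable validConstraints and are expanded into a cached lazy val. Here String stands in for Expression, and canonicalize is a placeholder for the real expansion/normalization step.

trait ConstraintSupport {
  // Subclasses override this with their operator-specific "raw" constraints.
  protected def validConstraints: Set[String] = Set.empty

  // Placeholder for the potentially expensive expansion/canonicalization.
  private def canonicalize(raw: Set[String]): Set[String] = raw.map(_.trim)

  // Computed at most once per node and then cached, so walking a large tree
  // does not recompute the constraint set repeatedly.
  lazy val constraints: Set[String] = canonicalize(validConstraints)
}

// Example: a node contributing one raw constraint; its `constraints` is Set("a > 10").
class FilterLikeNode extends ConstraintSupport {
  override protected def validConstraints: Set[String] = Set(" a > 10 ")
}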


/**
* Returns the set of attributes that are output by this node.
@@ -306,7 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper {

override def children: Seq[LogicalPlan] = child :: Nil

- override def constraints: Seq[Expression] = {
+ override def constraints: Set[Expression] = {
extractConstraintsFromChild(child)
}
}
@@ -89,9 +89,9 @@ case class Generate(
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

- override def constraints: Seq[Expression] = {
+ override def constraints: Set[Expression] = {
val newConstraint = splitConjunctivePredicates(condition).filter(
-   _.references.subsetOf(outputSet))
+   _.references.subsetOf(outputSet)).toSet
Contributor:
style nit: we typically avoid breaking in the middle of a function call and instead prefer to break in between calls (always pick the highest syntactic level)

val newConstraint = splitConjunctivePredicates(condition)
  .filter(_.references.subsetOf(outputSet))
  .toSet

newConstraint.union(extractConstraintsFromChild(child))
}
}
@@ -103,9 +103,9 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
}

- protected def leftConstraints: Seq[Expression] = extractConstraintsFromChild(left)
+ protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left)

- protected def rightConstraints: Seq[Expression] = {
+ protected def rightConstraints: Set[Expression] = {
require(left.output.size == right.output.size)
val attributeRewrites = AttributeMap(left.output.zip(right.output))
extractConstraintsFromChild(right).map(_ transform {
@@ -135,7 +135,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef
Statistics(sizeInBytes = sizeInBytes)
}

- override def constraints: Seq[Expression] = {
+ override def constraints: Set[Expression] = {
leftConstraints.intersect(rightConstraints)
}
}
@@ -147,7 +147,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}

- override def constraints: Seq[Expression] = {
+ override def constraints: Set[Expression] = {
leftConstraints.union(rightConstraints)
}
}
@@ -156,7 +156,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output

- override def constraints: Seq[Expression] = leftConstraints
+ override def constraints: Set[Expression] = leftConstraints
}

case class Join(
@@ -180,7 +180,7 @@ case class Join(
}
}

- override def constraints: Seq[Expression] = {
+ override def constraints: Set[Expression] = {
joinType match {
case LeftSemi =>
extractConstraintsFromChild(left)
@@ -19,11 +19,10 @@ package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
- import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.catalyst.plans.logical._
- import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+ import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.plans.logical._

/**
* This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
@@ -75,27 +74,14 @@ class LogicalPlanSuite extends SparkFunSuite {
}

test("propagating constraint in filter") {

- def resolve(plan: LogicalPlan, constraints: Seq[String]): Seq[Expression] = {
-   Seq(plan.resolve(constraints.map(_.toString), caseInsensitiveResolution).get)
- }
val tr = LocalRelation('a.int, 'b.string, 'c.int)
+ def resolveColumn(columnName: String): Expression =
+   tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
assert(tr.analyze.constraints.isEmpty)
assert(tr.select('a.attr).analyze.constraints.isEmpty)
- assert(tr.where('a.attr > 10).analyze.constraints.zip(Seq('a.attr > 10))
-   .forall(e => e._1.semanticEquals(e._2)))
- /*
- assert(tr.where('a.attr > 10).analyze.constraints == resolve(tr.where('a.attr > 10).analyze,
-   Seq("a > 10")))
- */
- /*
- assert(logicalPlan.constraints ==
-   Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution))
- assert(tr.where('a.attr > 10).select('c.attr).analyze.constraints.get == ('a > 10))
- assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
-   .analyze.constraints.get == And('a > 10, 'c < 100))
+ assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn("a") > 10))
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
- */
+ assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
+   .analyze.constraints == Set(resolveColumn("a") > 10, resolveColumn("c") < 100))
}
}