@@ -17,9 +17,9 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.plans.{LeftSemiExist, LeftSemiNotExist}
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -55,6 +55,7 @@ class Analyzer(

lazy val batches: Seq[Batch] = Seq(
Batch("Resolution", fixedPoint,
RewriteWhereClause ::
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
@@ -202,6 +203,58 @@ class Analyzer(
}
}

/**
* Rewrites the [[Exists]] and [[InSubquery]] operators in the WHERE clause into left semi joins
*/
object RewriteWhereClause extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
// TODO we don't support [EXISTS subquery] combined with other conjunctions yet
// e.g. from a where exists (select * from b where a.v = b.v) AND value > '112'
case e @ Exists(
SubqueryConjunction(left, None, None),
Project(_, Filter(condition, right)),
positive) =>
if (positive) {
Join(left, right, LeftSemiExist, Some(condition))
} else {
Join(left, right, LeftSemiNotExist, Some(condition))
}

// TODO we don't support [IN subquery] combined with other conjunctions yet
// e.g. where key in (select key from src_b) AND value > '112'
case e @ InSubquery(
SubqueryConjunction(left, Some(key), None),
Project(projectList, Filter(condition, right)),
positive) if e.left.resolved && projectList.length == 1 =>
// Possibly correlated (the subquery references its parent's attributes):
// fold the filter into the join condition. Even if it's not correlated,
// pulling the filter up into the left semi join is harmless, because the
// optimizer will push it back down again.
createSemiJoinForInSubquery(
left, right, And(condition, EqualTo(projectList(0), key)), positive)

case e @ InSubquery(
SubqueryConjunction(left, Some(key), None),
Project(projectList, child),
positive) if e.left.resolved && projectList.length == 1 =>
// it's uncorrelated: the subquery doesn't reference the outer query
createSemiJoinForInSubquery(left, e.right, EqualTo(projectList(0), key), positive)
}

def createSemiJoinForInSubquery(
left: LogicalPlan,
right: LogicalPlan,
condition: Expression,
positive: Boolean): Join = {
if (positive) {
Join(left, right, LeftSemiExist, Some(condition))
} else {
Join(left, right, LeftSemiNotExist, Some(condition))
}
}
}
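// A minimal sketch of the uncorrelated IN rewrite above (hypothetical table and
// column names): `SELECT * FROM a WHERE key IN (SELECT k FROM b)` arrives as
//   InSubquery(SubqueryConjunction(a, Some('key)), Project(Seq('k), b), positive = true)
// and leaves as
//   Join(a, Project(Seq('k), b), LeftSemiExist, Some(EqualTo('k, 'key)))
// NOT IN differs only in producing a LeftSemiNotExist join.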

/**
* Replaces [[UnresolvedAttribute]]s with concrete
* [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.{LeftSemiType, LeftSemiNotExist, LeftSemiExist}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -64,6 +66,9 @@ trait CheckAnalysis {
}

operator match {
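// A SubqueryConjunction that still carries a `condition` means the EXISTS/IN
// subquery was combined with another conjunction, which RewriteWhereClause
// does not handle yet, e.g. (illustrative):
//   SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE a.v = b.v) AND value > '112'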
case f @ SubqueryConjunction(_, _, Some(condition)) =>
failAnalysis(s"WHERE EXISTS/IS can not be combined with conjunctions, ${condition}")

case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
s"filter expression '${f.condition.prettyString}' " +
@@ -107,6 +112,13 @@
failAnalysis(
s"unresolved operator ${operator.simpleString}")

// We assume `[NOT] EXISTS` only supports equi-join conditions
// TODO can we support non-equi-joins as well? Performance concern?
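// e.g. (illustrative) `WHERE EXISTS (SELECT 1 FROM b WHERE a.v > b.v)` yields a
// join condition with no equality predicate, so ExtractEquiJoinKeys finds no keys
// and the query is rejected here.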
case o @ Join(_, _, LeftSemiExist | LeftSemiNotExist, _)
if ExtractEquiJoinKeys.unapply(o).isEmpty =>
failAnalysis(
s"condition $o doens't contain any equi-join key")

case _ => // Analysis successful!
}
}
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.LeftSemiType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -146,12 +146,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))

// Eliminate unneeded attributes from right side of a LeftSemiJoin.
case Join(left, right, LeftSemi, condition) =>
case Join(left, right, jt: LeftSemiType, condition) =>
// Collect the list of all references required to evaluate the condition.
val allReferences: AttributeSet =
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))

Join(left, prunedChild(right, allReferences), LeftSemi, condition)
Join(left, prunedChild(right, allReferences), jt, condition)

// Combine adjacent Projects.
case Project(projectList1, Project(projectList2, child)) =>
@@ -510,11 +510,11 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
* to evaluate them.
* @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
*/
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
protected def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (leftEvaluateCondition, rest) =
condition.partition(_.references subsetOf left.outputSet)
val (rightEvaluateCondition, commonCondition) =
rest.partition(_.references subsetOf right.outputSet)

(leftEvaluateCondition, rightEvaluateCondition, commonCondition)
}
@@ -545,7 +545,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {

(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case _ @ (LeftOuter | LeftSemi) =>
case (_ @ LeftOuter | _: LeftSemiType) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -581,7 +581,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

Join(newLeft, newRight, RightOuter, newJoinCond)
case _ @ (LeftOuter | LeftSemi) =>
case (_ @ LeftOuter | _: LeftSemiType) =>
// push down the right side only join filter for right sub query
val newLeft = left
val newRight = rightJoinConditions.
@@ -37,4 +37,16 @@ case object RightOuter extends JoinType

case object FullOuter extends JoinType

case object LeftSemi extends JoinType
abstract class LeftSemiType extends JoinType {
def exists: Boolean = true
}

case object LeftSemi extends LeftSemiType

// These are for internal use only, backing the [NOT] EXISTS | [NOT] IN clauses
case object LeftSemiNotExist extends LeftSemiType {
override def exists: Boolean = false
}

case object LeftSemiExist extends LeftSemiType
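// Usage sketch (not part of this file): matching on the supertype lets a rule
// handle all three semi-join flavors uniformly, e.g.
//   def isSemiJoin(jt: JoinType): Boolean = jt match {
//     case _: LeftSemiType => true   // LeftSemi, LeftSemiExist, LeftSemiNotExist
//     case _ => false
//   }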

@@ -96,7 +96,7 @@ case class Join(

override def output: Seq[Attribute] = {
joinType match {
case LeftSemi =>
case _: LeftSemiType =>
left.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -289,8 +289,34 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

// Exists, InSubquery & InFilter are used for subqueries in the WHERE clause ONLY
case class Exists(
left: SubqueryConjunction,
right: LogicalPlan,
positive: Boolean) extends BinaryNode {
override def output: Seq[Attribute] = Nil

override lazy val resolved = false
}

case class InSubquery(left: SubqueryConjunction, right: LogicalPlan, positive: Boolean)
extends BinaryNode {
override def output: Seq[Attribute] = Nil

override lazy val resolved = false
}

// This only connects the conjunction and the subquery in the WHERE clause.
// It is used solely within [[Exists]] and [[InSubquery]], as we want the
// attributes resolved before mapping the Exists or InSubquery to a left semi join.
case class SubqueryConjunction(
    child: LogicalPlan,
    key: Option[Expression] = None,
    condition: Option[Expression] = None) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
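// Note: `output` simply forwards the child's attributes, so references made by
// `key` and `condition` can resolve against the outer relation before
// RewriteWhereClause maps the enclosing Exists/InSubquery to a left semi join.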

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {

override def output: Seq[Attribute] = child.output
}
@@ -34,20 +34,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
case ExtractEquiJoinKeys(jt: LeftSemiType, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
val semiJoin = joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
leftKeys, rightKeys, planLater(left), planLater(right), jt)
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
case ExtractEquiJoinKeys(jt: LeftSemiType, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
leftKeys, rightKeys, planLater(left), planLater(right), jt)
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
case logical.Join(left, right, jt: LeftSemiType, condition) =>
joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition, jt) :: Nil
case _ => Nil
}
}
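// Planner recap: a broadcast hash semi-join is chosen when the right side is
// below autoBroadcastJoinThreshold, a shuffled hash semi-join when equi-join
// keys can be extracted, and the nested-loop LeftSemiJoinBNL otherwise; the
// join type `jt` is threaded through so each physical operator knows whether
// to keep matching rows (EXISTS/IN) or non-matching rows (NOT EXISTS/NOT IN).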
@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.plans.LeftSemiType
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
@@ -32,7 +34,8 @@ case class BroadcastLeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {
right: SparkPlan,
jt: LeftSemiType) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

@@ -59,9 +62,15 @@

streamedPlan.execute().mapPartitions { streamIter =>
val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
})
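// For LeftSemiExist (jt.exists) keep rows whose non-null key appears in the
// broadcast relation; for LeftSemiNotExist keep the complement: rows with a
// null key or with no match on the build side.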
if (jt.exists) {
streamIter.filter(current => {
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
})
} else {
streamIter.filter(current => {
joinKeys(current).anyNull || !broadcastedRelation.value.contains(joinKeys.currentValue)
})
}
}
}
}
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.LeftSemiType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -30,7 +31,10 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
*/
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
streamed: SparkPlan,
broadcast: SparkPlan,
condition: Option[Expression],
jt: LeftSemiType)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

@@ -68,7 +72,8 @@ case class LeftSemiJoinBNL(
}
i += 1
}
matched
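// `matched` records whether any broadcast-side row satisfied the condition;
// the NOT EXISTS variant keeps a streamed row only when nothing matched.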

if (jt.exists) matched else !matched
})
}
}
@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.plans.LeftSemiType
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -33,7 +35,8 @@ case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {
right: SparkPlan,
jt: LeftSemiType) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

@@ -60,9 +63,15 @@
}

val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
})
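// Same inversion as the broadcast variant: `exists` keeps rows whose key hits
// the build-side hash set, `not exists` keeps rows with null or unmatched keys.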
if (jt.exists) {
streamIter.filter(current => {
!joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
})
} else {
streamIter.filter(current => {
joinKeys(current).anyNull || !hashSet.contains(joinKeys.currentValue)
})
}
}
}
}
@@ -240,7 +240,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {

// It has a bug that has been fixed by
// https://issues.apache.org/jira/browse/HIVE-7673 (in Hive 0.14 and trunk).
"input46"
"input46",

// These tests contain window functions
"subquery_in",
"subquery_notin"
) ++ HiveShim.compatibilityBlackList

/**
@@ -993,5 +997,5 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"view",
"view_cast",
"view_inputs"
)
) ++ HiveShim.compatibilityWhiteList
}