use sort merge join for outer join
adrian-wang committed Aug 6, 2015
commit d95417ebb9f4a70e945615af93afb478cc1ac135
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -96,13 +96,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

// If the sort merge join option is set, we want to use sort merge join prior to hashjoin
// for now let's support inner join first, then add outer join
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
// If the sort merge join option is set, we prefer sort merge join over hash join.
// For outer joins, the condition cannot be evaluated outside of the join as a
// Filter, so it is passed into the SortMergeJoin operator itself.
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
joins.SortMergeJoin(
leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil

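To see why the ON condition cannot be hoisted above an outer join, here is a minimal sketch in plain Scala collections (not Spark; the data is hypothetical):

val left  = Seq((1, 3), (1, 1))   // (key, value)
val right = Seq((1, 1))           // (key, value)

// Condition evaluated inside the join: keys equal AND lv - rv > 1.
// Left rows with no surviving match are null-extended, as LEFT OUTER requires.
val insideJoin = left.map { case (lk, lv) =>
  right.find { case (rk, rv) => lk == rk && lv - rv > 1 } match {
    case Some((rk, rv)) => (lk, lv, Some(rk), Some(rv))
    case None           => (lk, lv, None, None)  // null-extended row survives
  }
}
// insideJoin == Seq((1, 3, Some(1), Some(1)), (1, 1, None, None))
// A Filter applied after an unconditional join would drop (1, 1, None, None)
// entirely, which is why the condition must be pushed into SortMergeJoin.
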
Contributor:
Shall we move this code to a new Strategy (like SortMergeJoin) instead of mixing it into HashJoin?

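A rough sketch of what that separate strategy could look like (hypothetical; assumes it is defined inside SparkStrategies so that sqlContext, planLater, and the joins package are in scope):

object SortMergeJoinStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
        if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
      joins.SortMergeJoin(
        leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil
    case _ => Nil
  }
}
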
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer
@@ -35,50 +36,100 @@ import org.apache.spark.util.collection.CompactBuffer
case class SortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
right: SparkPlan,
condition: Option[Expression] = None) extends BinaryNode {

override protected[sql] val trackNumOfRowsEnabled = true

override def output: Seq[Attribute] = left.output ++ right.output
val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match {
case RightOuter => (right, left, rightKeys, leftKeys)
case _ => (left, right, leftKeys, rightKeys)
}

override def output: Seq[Attribute] = joinType match {
case Inner =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
case x =>
throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType")
}

override def outputPartitioning: Partitioning =
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
override def outputPartitioning: Partitioning = joinType match {
case FullOuter =>
// When doing a Full Outer join, null rows generated on both sides do not follow the children's partitioning.
UnknownPartitioning(streamed.outputPartitioning.numPartitions)
case Inner =>
PartitioningCollection(Seq(streamed.outputPartitioning, buffered.outputPartitioning))
case LeftOuter | RightOuter =>
streamed.outputPartitioning
case x =>
throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType")
}

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
override def outputOrdering: Seq[SortOrder] = joinType match {
case FullOuter => Nil // In a Full Outer join, null rows from both sides are not ordered.
case _ => requiredOrders(streamedKeys)
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil

@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
@transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output)
@transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output)

// standard null rows
@transient private[this] lazy val streamedNullRow = new GenericRow(streamed.output.length)
Contributor:
Can we move streamedNullRow and bufferedNullRow into the zipPartitions call so that we don't have to make a transient lazy val? I find transient lazy val to be a bit confusing and like to avoid it when I can.

@transient private[this] lazy val bufferedNullRow = new GenericRow(buffered.output.length)
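A sketch of the refactor suggested above (hypothetical, not part of this commit): construct the null rows inside doExecute(), so the zipPartitions closure captures plain locals and the @transient lazy vals disappear:

protected override def doExecute(): RDD[InternalRow] = {
  val streamResults = streamed.execute().map(_.copy())
  val bufferResults = buffered.execute().map(_.copy())
  streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) =>
    // Built once per partition instead of once per operator instance.
    val streamedNullRow = new GenericRow(streamed.output.length)
    val bufferedNullRow = new GenericRow(buffered.output.length)
    // ... the existing Iterator[InternalRow] body goes here, closing over
    // the two local null rows ...
    Iterator.empty
  }
}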

// Checks whether the joined row satisfies the join condition.
@transient private[this] lazy val boundCondition =
condition.map(
newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row: InternalRow) => true)

private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}

protected override def doExecute(): RDD[InternalRow] = {
val leftResults = left.execute().map(_.copy())
val rightResults = right.execute().map(_.copy())
val streamResults = streamed.execute().map(_.copy())
val bufferResults = buffered.execute().map(_.copy())

leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) =>
new Iterator[InternalRow] {
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
private[this] var leftElement: InternalRow = _
private[this] var rightElement: InternalRow = _
private[this] var leftKey: InternalRow = _
private[this] var rightKey: InternalRow = _
private[this] var rightMatches: CompactBuffer[InternalRow] = _
private[this] var rightPosition: Int = -1
private[this] val joinRow = new JoinedRow5
private[this] var streamedElement: InternalRow = _
private[this] var bufferedElement: InternalRow = _
private[this] var streamedKey: InternalRow = _
private[this] var bufferedKey: InternalRow = _
private[this] var bufferedMatches: CompactBuffer[InternalRow] = _
private[this] var bufferedPosition: Int = -1
private[this] var stop: Boolean = false
private[this] var matchKey: InternalRow = _
// When the merge algorithm finds a non-matching join key, one side must lack a
// corresponding match, and we need to mark which side it is: true means the
// streamed side has no match, false means the buffered side. Only set when needed.
private[this] var continueStreamed: Boolean = _
// In a full outer join, once all matches for a key are found, we put a null
// streamed row here to tell next() to combine it with the buffered rows that
// matched the join key but failed the condition.
private[this] var secondStreamedElement: InternalRow = _
// Stores buffered rows that match the join key but fail the condition.
// These rows are needed when doing a Full Outer Join.
private[this] var secondBufferedMatches: CompactBuffer[InternalRow] = _
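// Hypothetical trace of this bookkeeping for a FULL OUTER join on one key group:
//   streamed rows with key k: s1 ; buffered rows with key k: b1, b2,
//   where the condition holds for (s1, b1) but fails for (s1, b2).
//   After step 2 of nextMatchingPair():
//     bufferedMatches       = [b1, bufferedNullRow]  // b2 replaced by a null row
//     secondBufferedMatches = [b2]                   // saved for the second pass
//     secondStreamedElement = streamedNullRow        // later paired with b2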

// initialize iterator
initialize()
@@ -87,84 +138,165 @@ case class SortMergeJoin(

override final def next(): InternalRow = {
if (hasNext) {
// we are using the buffered right rows and run down left iterator
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
rightPosition += 1
if (rightPosition >= rightMatches.size) {
rightPosition = 0
fetchLeft()
if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
stop = false
rightMatches = null
if (bufferedMatches == null || bufferedMatches.size == 0) {
// we just found a row with no join match and we are here to produce a row
// with this row with a standard null row from the other side.
Contributor:
Typo? Maybe this meant to read "with this row and a standard null row"?

if (continueStreamed) {
val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy())
Contributor:
Why do we have to copy the standard null rows? Are we worried about downstream operators mutating them?

fetchStreamed()
joinedRow
} else {
val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement)
fetchBuffered()
joinedRow
}
} else {
// we are reusing the buffered matches while running down the streamed iterator
val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition))
bufferedPosition += 1
if (bufferedPosition >= bufferedMatches.size) {
bufferedPosition = 0
if (joinType != FullOuter || secondStreamedElement == null) {
fetchStreamed()
Contributor:
I think we should use boundCondition to update bufferedMatches after we fetchStreamed(). Otherwise we may get a wrong answer. For example:

table a(key int,value int);table b(key int,value int)
data of a
1 3
1 1
2 1
2 3

data of b
1 1
2 1
select a.key,b.key,a.value-b.value from a left outer join b on a.key=b.key and a.value - b.value > 1

Contributor Author:

Good catch, I'll rewrite this part.
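For reference, the expected result of the reviewer's query under standard LEFT OUTER JOIN semantics, worked by hand:

// SELECT a.key, b.key, a.value - b.value
// FROM a LEFT OUTER JOIN b ON a.key = b.key AND a.value - b.value > 1
//
//   a.key | b.key | a.value - b.value
//   1     | 1     | 2        -- (1,3) vs (1,1): 3 - 1 > 1 holds
//   1     | null  | null     -- (1,1): key matches but 1 - 1 = 0 fails, null-extend
//   2     | null  | null     -- (2,1): 1 - 1 = 0 fails, null-extend
//   2     | 2     | 2        -- (2,3) vs (2,1): 3 - 1 > 1 holds
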

if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) {
stop = false
bufferedMatches = null
}
} else {
// In a FullOuter join, the first time we exhaust the match buffer we still
// need to generate rows that pair a null streamed row with the buffered
// rows that matched the join key but failed the condition.
streamedElement = secondStreamedElement
bufferedMatches = secondBufferedMatches
secondStreamedElement = null
secondBufferedMatches = null
}
}
joinedRow
}
joinedRow
} else {
// no more results
throw new NoSuchElementException
}
}

private def fetchLeft() = {
if (leftIter.hasNext) {
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow =
joinType match {
case RightOuter => joinRow(bufferedRow, streamedRow)
case _ => joinRow(streamedRow, bufferedRow)
}

private def fetchStreamed() = {
if (streamedIter.hasNext) {
streamedElement = streamedIter.next()
streamedKey = streamedKeyGenerator(streamedElement)
} else {
leftElement = null
streamedElement = null
}
}

private def fetchRight() = {
if (rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
private def fetchBuffered() = {
Contributor:
Same here: we should add : Unit = if we always ignore the return value.

if (bufferedIter.hasNext) {
bufferedElement = bufferedIter.next()
bufferedKey = bufferedKeyGenerator(bufferedElement)
} else {
rightElement = null
bufferedElement = null
}
}
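Applying the reviewer's suggestion, the declaration would read (a sketch):

// Explicit Unit return type, since callers always ignore the result.
private def fetchBuffered(): Unit = {
  if (bufferedIter.hasNext) {
    bufferedElement = bufferedIter.next()
    bufferedKey = bufferedKeyGenerator(bufferedElement)
  } else {
    bufferedElement = null
  }
}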

private def initialize() = {
fetchLeft()
fetchRight()
fetchStreamed()
fetchBuffered()
}

/**
* Searches the buffered iterator for the next rows that match the current streamed row,
* and stores them in a buffer.
* When this is not an Inner join, it also returns true when one side has a row with
* no match on the other side. Repeated calls resume from the same position until
* `next()` is called.
*
* @return true if the search found rows to emit, and false if the iterators run out
* of tuples.
*/
private def nextMatchingPair(): Boolean = {
Contributor:
Reading through the old version of the code, one tricky thing that stood out to me was the fact that if a consumer of this iterator calls hasNext() followed by next(), then we end up calling nextMatchingPair() two times in a row. It might be nice to add a comment here to explain why this is safe / correct.

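// A note on the reviewer's question above (sketch of the invariant, assuming the
// code below is unchanged): after a successful call, `stop` is true and
// `bufferedMatches` is non-empty, so a repeated call skips both step 1 and step 2
// and only re-evaluates the final size check without advancing either iterator.
// The early `return true` paths likewise leave both iterators untouched, so a
// second call recomputes the same comparison and returns the same answer. That is
// what makes back-to-back hasNext()/next() calls safe.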
if (!stop && rightElement != null) {
// run both side to get the first match pair
while (!stop && leftElement != null && rightElement != null) {
val comparing = keyOrdering.compare(leftKey, rightKey)
if (!stop && streamedElement != null) {
// step 1: run down both sides to find the first matching pair
while (!stop && streamedElement != null && bufferedElement != null) {
val comparing = keyOrdering.compare(streamedKey, bufferedKey)
// for inner join, we need to filter out rows with null keys
stop = comparing == 0 && !leftKey.anyNull
if (comparing > 0 || rightKey.anyNull) {
fetchRight()
} else if (comparing < 0 || leftKey.anyNull) {
fetchLeft()
stop = comparing == 0 && !streamedKey.anyNull
if (comparing > 0 || bufferedKey.anyNull) {
if (joinType == FullOuter) {
// The join type is full outer and the buffered side has a row with no join
// match, so we emit a result row pairing a null streamed row with this
// buffered row. Then we fetch the next buffered element and resume.
continueStreamed = false
return true
} else {
fetchBuffered()
}
} else if (comparing < 0 || streamedKey.anyNull) {
if (joinType == Inner) {
fetchStreamed()
} else {
// The join type is not inner and the streamed side has a row with no join
// match, so we emit a result row pairing this streamed row with a null
// buffered row. Then we fetch the next streamed element and resume.
continueStreamed = true
return true
}
}
}
rightMatches = new CompactBuffer[InternalRow]()
// step 2: run down the buffered side to put all matched rows in a buffer
bufferedMatches = new CompactBuffer[InternalRow]()
secondBufferedMatches = new CompactBuffer[InternalRow]()
if (stop) {
stop = false
// iterate the buffered side to collect all rows that match
// as the records are ordered, exit at the first one that does not match
while (!stop && rightElement != null) {
rightMatches += rightElement
fetchRight()
stop = keyOrdering.compare(leftKey, rightKey) != 0
while (!stop) {
if (boundCondition(joinRow(streamedElement, bufferedElement))) {
Contributor:
Maybe we also need to keep the row when the key matches but boundCondition returns false, for the next streamed row's update.

bufferedMatches += bufferedElement
} else if (joinType == FullOuter) {
bufferedMatches += bufferedNullRow.copy()
secondBufferedMatches += bufferedElement
}
fetchBuffered()
stop =
keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null
}
if (bufferedMatches.size == 0 && joinType != Inner) {
bufferedMatches += bufferedNullRow.copy()
}
if (bufferedMatches.size > 0) {
bufferedPosition = 0
matchKey = streamedKey
// secondBufferedMatches.size cannot be larger than bufferedMatches.size
if (secondBufferedMatches.size > 0) {
secondStreamedElement = streamedNullRow.copy()
}
}
}
}
// `stop` is false here iff one side finished iterating in step 1;
// if step 2 ran, `stop` has been set back to true.
if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) {
if (streamedElement == null && bufferedElement != null) {
// streamedElement == null but bufferedElement != null
if (joinType == FullOuter) {
continueStreamed = false
return true
}
if (rightMatches.size > 0) {
rightPosition = 0
matchKey = leftKey
} else if (streamedElement != null && bufferedElement == null) {
// bufferedElement == null but streamedElement != null
if (joinType != Inner) {
continueStreamed = true
return true
}
}
}
rightMatches != null && rightMatches.size > 0
bufferedMatches != null && bufferedMatches.size > 0
}
}
}
8 changes: 7 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -102,7 +102,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Seq(
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[SortMergeJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[SortMergeJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)