Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rebase
  • Loading branch information
adrian-wang committed Aug 6, 2015
commit 71ff4e910b574e2e9ef0b839558abc32569eb193
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ case class SortMergeJoin(

override protected[sql] val trackNumOfRowsEnabled = true

val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match {
val (streamedPlan, bufferedPlan, streamedKeys, bufferedKeys) = joinType match {
case RightOuter => (right, left, rightKeys, leftKeys)
case _ => (left, right, leftKeys, rightKeys)
}
Expand All @@ -64,11 +64,11 @@ case class SortMergeJoin(
override def outputPartitioning: Partitioning = joinType match {
case FullOuter =>
// when doing Full Outer join, NULL rows from both sides are not so partitioned.
UnknownPartitioning(streamed.outputPartitioning.numPartitions)
UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions)
case Inner =>
PartitioningCollection(Seq(streamed.outputPartitioning, buffered.outputPartitioning))
PartitioningCollection(Seq(streamedPlan.outputPartitioning, bufferedPlan.outputPartitioning))
case LeftOuter | rightOuter =>
streamed.outputPartitioning
streamedPlan.outputPartitioning
case x =>
throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType")
}
Expand All @@ -84,28 +84,29 @@ case class SortMergeJoin(
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil

@transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output)
@transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output)

// standard null rows
@transient private[this] lazy val streamedNullRow = new GenericRow(streamed.output.length)
@transient private[this] lazy val bufferedNullRow = new GenericRow(buffered.output.length)
@transient protected lazy val streamedKeyGenerator =
newProjection(streamedKeys, streamedPlan.output)
@transient protected lazy val bufferedKeyGenerator =
newProjection(bufferedKeys, bufferedPlan.output)

// checks if the joinedRow can meet condition requirements
@transient private[this] lazy val boundCondition =
condition.map(
newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row: InternalRow) => true)
condition.map(newPredicate(_, streamedPlan.output ++ bufferedPlan.output)).getOrElse(
(row: InternalRow) => true)

private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}

protected override def doExecute(): RDD[InternalRow] = {
val streamResults = streamed.execute().map(_.copy())
val bufferResults = buffered.execute().map(_.copy())
val streamResults = streamedPlan.execute().map(_.copy())
val bufferResults = bufferedPlan.execute().map(_.copy())

streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) =>
streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => {
// standard null rows
val streamedNullRow = new GenericRow(streamedPlan.output.length)
val bufferedNullRow = new GenericRow(bufferedPlan.output.length)
new Iterator[InternalRow] {
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
Expand Down Expand Up @@ -140,13 +141,13 @@ case class SortMergeJoin(
if (hasNext) {
if (bufferedMatches == null || bufferedMatches.size == 0) {
// we just found a row with no join match and we are here to produce a row
// with this row with a standard null row from the other side.
// with this row and a standard null row from the other side.
if (continueStreamed) {
val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy())
val joinedRow = smartJoinRow(streamedElement, bufferedNullRow)
fetchStreamed()
joinedRow
} else {
val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement)
val joinedRow = smartJoinRow(streamedNullRow, bufferedElement)
fetchBuffered()
joinedRow
}
Expand Down Expand Up @@ -186,7 +187,7 @@ case class SortMergeJoin(
case _ => joinRow(streamedRow, bufferedRow)
}

private def fetchStreamed() = {
private def fetchStreamed(): Unit = {
if (streamedIter.hasNext) {
streamedElement = streamedIter.next()
streamedKey = streamedKeyGenerator(streamedElement)
Expand All @@ -195,7 +196,7 @@ case class SortMergeJoin(
}
}

private def fetchBuffered() = {
private def fetchBuffered(): Unit = {
if (bufferedIter.hasNext) {
bufferedElement = bufferedIter.next()
bufferedKey = bufferedKeyGenerator(bufferedElement)
Expand All @@ -215,6 +216,8 @@ case class SortMergeJoin(
* When this is not an Inner join, we will also return true when we get a row with no match
* on the other side. This search will jump out every time from the same position until
* `next()` is called.
* Until `next()` is called, this function can be called multiple times with the same
* return value and effect as calling it once, because guards inside it prevent repeated work.
*
* @return true if the search is successful, and false if the right iterator runs out of
* tuples.
Expand Down Expand Up @@ -259,22 +262,22 @@ case class SortMergeJoin(
if (boundCondition(joinRow(streamedElement, bufferedElement))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we also need to keep the row when the keys match but boundCondition returns false, so that it is available when the next streamed row is processed.

bufferedMatches += bufferedElement
} else if (joinType == FullOuter) {
bufferedMatches += bufferedNullRow.copy()
bufferedMatches += bufferedNullRow
secondBufferedMatches += bufferedElement
}
fetchBuffered()
stop =
keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null
}
if (bufferedMatches.size == 0 && joinType != Inner) {
bufferedMatches += bufferedNullRow.copy()
bufferedMatches += bufferedNullRow
}
if (bufferedMatches.size > 0) {
bufferedPosition = 0
matchKey = streamedKey
// secondBufferedMatches.size cannot be larger than bufferedMatches
if (secondBufferedMatches.size > 0) {
secondStreamedElement = streamedNullRow.copy()
secondStreamedElement = streamedNullRow
}
}
}
Expand All @@ -299,6 +302,6 @@ case class SortMergeJoin(
bufferedMatches != null && bufferedMatches.size > 0
}
}
}
})
}
}