@@ -36,46 +36,91 @@ import org.apache.spark.util.collection.CompactBuffer
3636case class SortMergeJoin (
3737 leftKeys : Seq [Expression ],
3838 rightKeys : Seq [Expression ],
39+ joinType : JoinType ,
3940 left : SparkPlan ,
40- right : SparkPlan ) extends BinaryNode {
41+ right : SparkPlan ,
42+ condition : Option [Expression ] = None ) extends BinaryNode {
4143
42- override def output : Seq [Attribute ] = left.output ++ right.output
44+ val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match {
45+ case RightOuter => (right, left, rightKeys, leftKeys)
46+ case _ => (left, right, leftKeys, rightKeys)
47+ }
48+
49+ override def output : Seq [Attribute ] = joinType match {
50+ case Inner =>
51+ left.output ++ right.output
52+ case LeftOuter =>
53+ left.output ++ right.output.map(_.withNullability(true ))
54+ case RightOuter =>
55+ left.output.map(_.withNullability(true )) ++ right.output
56+ case FullOuter =>
57+ left.output.map(_.withNullability(true )) ++ right.output.map(_.withNullability(true ))
58+ case x =>
59+ throw new Exception (s " SortMergeJoin should not take $x as the JoinType " )
60+ }
4361
44- override def outputPartitioning : Partitioning = left.outputPartitioning
62+ override def outputPartitioning : Partitioning = joinType match {
63+ case FullOuter =>
64+ // when doing Full Outer join, NULL rows from both sides are not so partitioned.
65+ UnknownPartitioning (streamed.outputPartitioning.numPartitions)
66+ case _ => streamed.outputPartitioning
67+ }
4568
4669 override def requiredChildDistribution : Seq [Distribution ] =
4770 ClusteredDistribution (leftKeys) :: ClusteredDistribution (rightKeys) :: Nil
4871
4972 // this is to manually construct an ordering that can be used to compare keys from both sides
50- private val keyOrdering : RowOrdering = RowOrdering .forSchema(leftKeys .map(_.dataType))
73+ private val keyOrdering : RowOrdering = RowOrdering .forSchema(streamedKeys .map(_.dataType))
5174
52- override def outputOrdering : Seq [SortOrder ] = requiredOrders(leftKeys)
75+ override def outputOrdering : Seq [SortOrder ] = joinType match {
76+ case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered.
77+ case _ => requiredOrders(streamedKeys)
78+ }
5379
5480 override def requiredChildOrdering : Seq [Seq [SortOrder ]] =
5581 requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
5682
57- @ transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
58- @ transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
83+ @ transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output)
84+ @ transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output)
85+
86+ // standard null rows
87+ @ transient private [this ] lazy val streamedNullRow = new GenericRow (streamed.output.length)
88+ @ transient private [this ] lazy val bufferedNullRow = new GenericRow (buffered.output.length)
89+
90+ // checks if the joinedRow can meet condition requirements
91+ @ transient private [this ] lazy val boundCondition =
92+ condition.map(newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row : Row ) => true )
5993
6094 private def requiredOrders (keys : Seq [Expression ]): Seq [SortOrder ] =
6195 keys.map(SortOrder (_, Ascending ))
6296
6397 override def execute (): RDD [Row ] = {
64- val leftResults = left .execute().map(_.copy())
65- val rightResults = right .execute().map(_.copy())
98+ val streamResults = streamed .execute().map(_.copy())
99+ val bufferResults = buffered .execute().map(_.copy())
66100
67- leftResults .zipPartitions(rightResults ) { (leftIter, rightIter ) =>
101+ streamResults .zipPartitions(bufferResults ) { (streamedIter, bufferedIter ) =>
68102 new Iterator [Row ] {
69103 // Mutable per row objects.
70104 private [this ] val joinRow = new JoinedRow5
71- private [this ] var leftElement : Row = _
72- private [this ] var rightElement : Row = _
73- private [this ] var leftKey : Row = _
74- private [this ] var rightKey : Row = _
75- private [this ] var rightMatches : CompactBuffer [Row ] = _
76- private [this ] var rightPosition : Int = - 1
105+ private [this ] var streamedElement : Row = _
106+ private [this ] var bufferedElement : Row = _
107+ private [this ] var streamedKey : Row = _
108+ private [this ] var bufferedKey : Row = _
109+ private [this ] var bufferedMatches : CompactBuffer [Row ] = _
110+ private [this ] var bufferedPosition : Int = - 1
77111 private [this ] var stop : Boolean = false
78112 private [this ] var matchKey : Row = _
113+ // when we do merge algorithm and find some not matched join key, there must be a side
114+ // that do not have a corresponding match. So we need to mark which side it is. True means
115+ // streamed side not have match, and False means the buffered side. Only set when needed.
116+ private [this ] var continueStreamed : Boolean = _
117+ // when we do full outer join and find all matched keys, we put a null stream row into
118+ // this to tell next() that we need to combine null stream row with all rows that not match
119+ // conditions.
120+ private [this ] var secondStreamedElement : Row = _
121+ // Stores rows that match the join key but not match conditions.
122+ // These rows will be useful when we are doing Full Outer Join.
123+ private [this ] var secondBufferedMatches : CompactBuffer [Row ] = _
79124
80125 // initialize iterator
81126 initialize()
@@ -84,84 +129,164 @@ case class SortMergeJoin(
84129
85130 override final def next (): Row = {
86131 if (hasNext) {
87- // we are using the buffered right rows and run down left iterator
88- val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
89- rightPosition += 1
90- if (rightPosition >= rightMatches.size) {
91- rightPosition = 0
92- fetchLeft()
93- if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0 ) {
94- stop = false
95- rightMatches = null
132+ if (bufferedMatches == null || bufferedMatches.size == 0 ) {
133+ // we just found a row with no join match and we are here to produce a row
134+ // with this row with a standard null row from the other side.
135+ if (continueStreamed) {
136+ val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy())
137+ fetchStreamed()
138+ joinedRow
139+ } else {
140+ val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement)
141+ fetchBuffered()
142+ joinedRow
143+ }
144+ } else {
145+ // we are using the buffered right rows and run down left iterator
146+ val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition))
147+ bufferedPosition += 1
148+ if (bufferedPosition >= bufferedMatches.size) {
149+ bufferedPosition = 0
150+ if (joinType != FullOuter || secondStreamedElement == null ) {
151+ fetchStreamed()
152+ if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0 ) {
153+ stop = false
154+ bufferedMatches = null
155+ }
156+ } else {
157+ // in FullOuter join and the first time we finish the match buffer,
158+ // we still want to generate all rows with streamed null row and buffered
159+ // rows that match the join key but not the conditions.
160+ streamedElement = secondStreamedElement
161+ bufferedMatches = secondBufferedMatches
162+ secondStreamedElement = null
163+ secondBufferedMatches = null
164+ }
96165 }
166+ joinedRow
97167 }
98- joinedRow
99168 } else {
100169 // no more result
101170 throw new NoSuchElementException
102171 }
103172 }
104173
105- private def fetchLeft () = {
106- if (leftIter.hasNext) {
107- leftElement = leftIter.next()
108- leftKey = leftKeyGenerator(leftElement)
174+ private def smartJoinRow (streamedRow : Row , bufferedRow : Row ): Row = joinType match {
175+ case RightOuter => joinRow(bufferedRow, streamedRow)
176+ case _ => joinRow(streamedRow, bufferedRow)
177+ }
178+
179+ private def fetchStreamed () = {
180+ if (streamedIter.hasNext) {
181+ streamedElement = streamedIter.next()
182+ streamedKey = streamedKeyGenerator(streamedElement)
109183 } else {
110- leftElement = null
184+ streamedElement = null
111185 }
112186 }
113187
114- private def fetchRight () = {
115- if (rightIter .hasNext) {
116- rightElement = rightIter .next()
117- rightKey = rightKeyGenerator(rightElement )
188+ private def fetchBuffered () = {
189+ if (bufferedIter .hasNext) {
190+ bufferedElement = bufferedIter .next()
191+ bufferedKey = bufferedKeyGenerator(bufferedElement )
118192 } else {
119- rightElement = null
193+ bufferedElement = null
120194 }
121195 }
122196
123197 private def initialize () = {
124- fetchLeft ()
125- fetchRight ()
198+ fetchStreamed ()
199+ fetchBuffered ()
126200 }
127201
128202 /**
129203 * Searches the right iterator for the next rows that have matches in left side, and store
130204 * them in a buffer.
205+ * When this is not a Inner join, we will also return true when we get a row with no match
206+ * on the other side. This search will jump out every time from the same position until
207+ * `next()` is called.
131208 *
132209 * @return true if the search is successful, and false if the right iterator runs out of
133210 * tuples.
134211 */
135212 private def nextMatchingPair (): Boolean = {
136- if (! stop && rightElement != null ) {
137- // run both side to get the first match pair
138- while (! stop && leftElement != null && rightElement != null ) {
139- val comparing = keyOrdering.compare(leftKey, rightKey )
213+ if (! stop && streamedElement != null ) {
214+ // step 1: run both side to get the first match pair
215+ while (! stop && streamedElement != null && bufferedElement != null ) {
216+ val comparing = keyOrdering.compare(streamedKey, bufferedKey )
140217 // for inner join, we need to filter those null keys
141- stop = comparing == 0 && ! leftKey.anyNull
142- if (comparing > 0 || rightKey.anyNull) {
143- fetchRight()
144- } else if (comparing < 0 || leftKey.anyNull) {
145- fetchLeft()
218+ stop = comparing == 0 && ! streamedKey.anyNull
219+ if (comparing > 0 || bufferedKey.anyNull) {
220+ if (joinType == FullOuter ) {
221+ // the join type is full outer and the buffered side has a row with no
222+ // join match, so we have a result row with streamed null with buffered
223+ // side as this row. Then we fetch next buffered element and go back.
224+ continueStreamed = false
225+ return true
226+ } else {
227+ fetchBuffered()
228+ }
229+ } else if (comparing < 0 || streamedKey.anyNull) {
230+ if (joinType == Inner ) {
231+ fetchStreamed()
232+ } else {
233+ // the join type is not inner and the streamed side has a row with no
234+ // join match, so we have a result row with this streamed row with buffered
235+ // null row. Then we fetch next streamed element and go back.
236+ continueStreamed = true
237+ return true
238+ }
146239 }
147240 }
148- rightMatches = new CompactBuffer [Row ]()
241+ // step 2: run down the buffered side to put all matched rows in a buffer
242+ bufferedMatches = new CompactBuffer [Row ]()
243+ secondBufferedMatches = new CompactBuffer [Row ]()
149244 if (stop) {
150245 stop = false
151246 // iterate the right side to buffer all rows that matches
152247 // as the records should be ordered, exit when we meet the first that not match
153- while (! stop && rightElement != null ) {
154- rightMatches += rightElement
155- fetchRight()
156- stop = keyOrdering.compare(leftKey, rightKey) != 0
248+ while (! stop) {
249+ if (boundCondition(joinRow(streamedElement, bufferedElement))) {
250+ bufferedMatches += bufferedElement
251+ } else if (joinType == FullOuter ) {
252+ bufferedMatches += bufferedNullRow.copy()
253+ secondBufferedMatches += bufferedElement
254+ }
255+ fetchBuffered()
256+ stop =
257+ keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null
258+ }
259+ if (bufferedMatches.size == 0 && joinType != Inner ) {
260+ bufferedMatches += bufferedNullRow.copy()
261+ }
262+ if (bufferedMatches.size > 0 ) {
263+ bufferedPosition = 0
264+ matchKey = streamedKey
265+ // secondBufferedMatches.size cannot be larger than bufferedMatches
266+ if (secondBufferedMatches.size > 0 ) {
267+ secondStreamedElement = streamedNullRow.copy()
268+ }
269+ }
270+ }
271+ }
272+ // `stop` is false iff left or right has finished iteration in step 1.
273+ // if we get into step 2, `stop` cannot be false.
274+ if (! stop && (bufferedMatches == null || bufferedMatches.size == 0 )) {
275+ if (streamedElement == null && bufferedElement != null ) {
276+ // streamedElement == null but bufferedElement != null
277+ if (joinType == FullOuter ) {
278+ continueStreamed = false
279+ return true
157280 }
158- if (rightMatches.size > 0 ) {
159- rightPosition = 0
160- matchKey = leftKey
281+ } else if (streamedElement != null && bufferedElement == null ) {
282+ // bufferedElement == null but streamedElement != null
283+ if (joinType != Inner ) {
284+ continueStreamed = true
285+ return true
161286 }
162287 }
163288 }
164- rightMatches != null && rightMatches .size > 0
289+ bufferedMatches != null && bufferedMatches .size > 0
165290 }
166291 }
167292 }
0 commit comments