Skip to content

Commit fc862f4

Browse files
committed
use sort merge join for outer join
1 parent bdc5c16 commit fc862f4

File tree

2 files changed

+186
-62
lines changed

2 files changed

+186
-62
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
9090
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
9191
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
9292

93-
// If the sort merge join option is set, we want to use sort merge join prior to hashjoin
94-
// for now let's support inner join first, then add outer join
95-
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
93+
// If the sort merge join option is set, we want to use sort merge join prior to hashjoin.
94+
// For outer joins, the join condition cannot be applied as a filter outside of the join.
95+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
9696
if sqlContext.conf.sortMergeJoinEnabled =>
97-
val mergeJoin =
98-
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
99-
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
97+
joins.SortMergeJoin(
98+
leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil
10099

101100
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
102101
val buildSide =

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 181 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -36,46 +36,91 @@ import org.apache.spark.util.collection.CompactBuffer
3636
case class SortMergeJoin(
3737
leftKeys: Seq[Expression],
3838
rightKeys: Seq[Expression],
39+
joinType: JoinType,
3940
left: SparkPlan,
40-
right: SparkPlan) extends BinaryNode {
41+
right: SparkPlan,
42+
condition: Option[Expression] = None) extends BinaryNode {
4143

42-
override def output: Seq[Attribute] = left.output ++ right.output
44+
val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match {
45+
case RightOuter => (right, left, rightKeys, leftKeys)
46+
case _ => (left, right, leftKeys, rightKeys)
47+
}
48+
49+
override def output: Seq[Attribute] = joinType match {
50+
case Inner =>
51+
left.output ++ right.output
52+
case LeftOuter =>
53+
left.output ++ right.output.map(_.withNullability(true))
54+
case RightOuter =>
55+
left.output.map(_.withNullability(true)) ++ right.output
56+
case FullOuter =>
57+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
58+
case x =>
59+
throw new Exception(s"SortMergeJoin should not take $x as the JoinType")
60+
}
4361

44-
override def outputPartitioning: Partitioning = left.outputPartitioning
62+
override def outputPartitioning: Partitioning = joinType match {
63+
case FullOuter =>
64+
// For a full outer join, the NULL-padded rows from both sides are not partitioned by key.
65+
UnknownPartitioning(streamed.outputPartitioning.numPartitions)
66+
case _ => streamed.outputPartitioning
67+
}
4568

4669
override def requiredChildDistribution: Seq[Distribution] =
4770
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
4871

4972
// this is to manually construct an ordering that can be used to compare keys from both sides
50-
private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
73+
private val keyOrdering: RowOrdering = RowOrdering.forSchema(streamedKeys.map(_.dataType))
5174

52-
override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
75+
override def outputOrdering: Seq[SortOrder] = joinType match {
76+
case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered.
77+
case _ => requiredOrders(streamedKeys)
78+
}
5379

5480
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
5581
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
5682

57-
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
58-
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
83+
@transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output)
84+
@transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output)
85+
86+
// standard null rows
87+
@transient private[this] lazy val streamedNullRow = new GenericRow(streamed.output.length)
88+
@transient private[this] lazy val bufferedNullRow = new GenericRow(buffered.output.length)
89+
90+
// checks if the joinedRow can meet condition requirements
91+
@transient private[this] lazy val boundCondition =
92+
condition.map(newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row: Row) => true)
5993

6094
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
6195
keys.map(SortOrder(_, Ascending))
6296

6397
override def execute(): RDD[Row] = {
64-
val leftResults = left.execute().map(_.copy())
65-
val rightResults = right.execute().map(_.copy())
98+
val streamResults = streamed.execute().map(_.copy())
99+
val bufferResults = buffered.execute().map(_.copy())
66100

67-
leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
101+
streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) =>
68102
new Iterator[Row] {
69103
// Mutable per row objects.
70104
private[this] val joinRow = new JoinedRow5
71-
private[this] var leftElement: Row = _
72-
private[this] var rightElement: Row = _
73-
private[this] var leftKey: Row = _
74-
private[this] var rightKey: Row = _
75-
private[this] var rightMatches: CompactBuffer[Row] = _
76-
private[this] var rightPosition: Int = -1
105+
private[this] var streamedElement: Row = _
106+
private[this] var bufferedElement: Row = _
107+
private[this] var streamedKey: Row = _
108+
private[this] var bufferedKey: Row = _
109+
private[this] var bufferedMatches: CompactBuffer[Row] = _
110+
private[this] var bufferedPosition: Int = -1
77111
private[this] var stop: Boolean = false
78112
private[this] var matchKey: Row = _
113+
// When the merge algorithm finds a join key with no match, one side must have a row
114+
// without a corresponding match, so we need to mark which side it is. True means the
115+
// streamed row has no match, and False means the buffered row. Only set when needed.
116+
private[this] var continueStreamed: Boolean = _
117+
// In a full outer join, once all rows matching the join key are found, we put a null
118+
// streamed row here to tell next() to combine it with all buffered rows that do not
119+
// satisfy the condition.
120+
private[this] var secondStreamedElement: Row = _
121+
// Stores rows that match the join key but do not satisfy the condition.
122+
// These rows are needed when doing a full outer join.
123+
private[this] var secondBufferedMatches: CompactBuffer[Row] = _
79124

80125
// initialize iterator
81126
initialize()
@@ -84,84 +129,164 @@ case class SortMergeJoin(
84129

85130
override final def next(): Row = {
86131
if (hasNext) {
87-
// we are using the buffered right rows and run down left iterator
88-
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
89-
rightPosition += 1
90-
if (rightPosition >= rightMatches.size) {
91-
rightPosition = 0
92-
fetchLeft()
93-
if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
94-
stop = false
95-
rightMatches = null
132+
if (bufferedMatches == null || bufferedMatches.size == 0) {
133+
// we just found a row with no join match, so we produce a result row that
134+
// joins this row with a standard null row from the other side.
135+
if (continueStreamed) {
136+
val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy())
137+
fetchStreamed()
138+
joinedRow
139+
} else {
140+
val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement)
141+
fetchBuffered()
142+
joinedRow
143+
}
144+
} else {
145+
// we are using the buffered rows and running down the streamed iterator
146+
val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition))
147+
bufferedPosition += 1
148+
if (bufferedPosition >= bufferedMatches.size) {
149+
bufferedPosition = 0
150+
if (joinType != FullOuter || secondStreamedElement == null) {
151+
fetchStreamed()
152+
if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) {
153+
stop = false
154+
bufferedMatches = null
155+
}
156+
} else {
157+
// in FullOuter join and the first time we finish the match buffer,
158+
// we still want to generate all rows with streamed null row and buffered
159+
// rows that match the join key but not the conditions.
160+
streamedElement = secondStreamedElement
161+
bufferedMatches = secondBufferedMatches
162+
secondStreamedElement = null
163+
secondBufferedMatches = null
164+
}
96165
}
166+
joinedRow
97167
}
98-
joinedRow
99168
} else {
100169
// no more result
101170
throw new NoSuchElementException
102171
}
103172
}
104173

105-
private def fetchLeft() = {
106-
if (leftIter.hasNext) {
107-
leftElement = leftIter.next()
108-
leftKey = leftKeyGenerator(leftElement)
174+
private def smartJoinRow(streamedRow: Row, bufferedRow: Row): Row = joinType match {
175+
case RightOuter => joinRow(bufferedRow, streamedRow)
176+
case _ => joinRow(streamedRow, bufferedRow)
177+
}
178+
179+
private def fetchStreamed() = {
180+
if (streamedIter.hasNext) {
181+
streamedElement = streamedIter.next()
182+
streamedKey = streamedKeyGenerator(streamedElement)
109183
} else {
110-
leftElement = null
184+
streamedElement = null
111185
}
112186
}
113187

114-
private def fetchRight() = {
115-
if (rightIter.hasNext) {
116-
rightElement = rightIter.next()
117-
rightKey = rightKeyGenerator(rightElement)
188+
private def fetchBuffered() = {
189+
if (bufferedIter.hasNext) {
190+
bufferedElement = bufferedIter.next()
191+
bufferedKey = bufferedKeyGenerator(bufferedElement)
118192
} else {
119-
rightElement = null
193+
bufferedElement = null
120194
}
121195
}
122196

123197
private def initialize() = {
124-
fetchLeft()
125-
fetchRight()
198+
fetchStreamed()
199+
fetchBuffered()
126200
}
127201

128202
/**
129203
* Searches the buffered iterator for the next rows that match the streamed side, and stores
130204
* them in a buffer.
205+
* When this is not an inner join, we will also return true when we get a row with no match
206+
* on the other side. The search returns from the same position on every call until
207+
* `next()` is called.
131208
*
132209
* @return true if the search is successful, and false if the right iterator runs out of
133210
* tuples.
134211
*/
135212
private def nextMatchingPair(): Boolean = {
136-
if (!stop && rightElement != null) {
137-
// run both side to get the first match pair
138-
while (!stop && leftElement != null && rightElement != null) {
139-
val comparing = keyOrdering.compare(leftKey, rightKey)
213+
if (!stop && streamedElement != null) {
214+
// step 1: run both side to get the first match pair
215+
while (!stop && streamedElement != null && bufferedElement != null) {
216+
val comparing = keyOrdering.compare(streamedKey, bufferedKey)
140217
// for inner join, we need to filter those null keys
141-
stop = comparing == 0 && !leftKey.anyNull
142-
if (comparing > 0 || rightKey.anyNull) {
143-
fetchRight()
144-
} else if (comparing < 0 || leftKey.anyNull) {
145-
fetchLeft()
218+
stop = comparing == 0 && !streamedKey.anyNull
219+
if (comparing > 0 || bufferedKey.anyNull) {
220+
if (joinType == FullOuter) {
221+
// the join type is full outer and the buffered side has a row with no
222+
// join match, so we have a result row with streamed null with buffered
223+
// side as this row. Then we fetch next buffered element and go back.
224+
continueStreamed = false
225+
return true
226+
} else {
227+
fetchBuffered()
228+
}
229+
} else if (comparing < 0 || streamedKey.anyNull) {
230+
if (joinType == Inner) {
231+
fetchStreamed()
232+
} else {
233+
// the join type is not inner and the streamed side has a row with no
234+
// join match, so we have a result row with this streamed row with buffered
235+
// null row. Then we fetch next streamed element and go back.
236+
continueStreamed = true
237+
return true
238+
}
146239
}
147240
}
148-
rightMatches = new CompactBuffer[Row]()
241+
// step 2: run down the buffered side to put all matched rows in a buffer
242+
bufferedMatches = new CompactBuffer[Row]()
243+
secondBufferedMatches = new CompactBuffer[Row]()
149244
if (stop) {
150245
stop = false
151246
// iterate the right side to buffer all rows that matches
152247
// as the records should be ordered, exit when we meet the first that not match
153-
while (!stop && rightElement != null) {
154-
rightMatches += rightElement
155-
fetchRight()
156-
stop = keyOrdering.compare(leftKey, rightKey) != 0
248+
while (!stop) {
249+
if (boundCondition(joinRow(streamedElement, bufferedElement))) {
250+
bufferedMatches += bufferedElement
251+
} else if (joinType == FullOuter) {
252+
bufferedMatches += bufferedNullRow.copy()
253+
secondBufferedMatches += bufferedElement
254+
}
255+
fetchBuffered()
256+
stop =
257+
keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null
258+
}
259+
if (bufferedMatches.size == 0 && joinType != Inner) {
260+
bufferedMatches += bufferedNullRow.copy()
261+
}
262+
if (bufferedMatches.size > 0) {
263+
bufferedPosition = 0
264+
matchKey = streamedKey
265+
// secondBufferedMatches.size cannot be larger than bufferedMatches
266+
if (secondBufferedMatches.size > 0) {
267+
secondStreamedElement = streamedNullRow.copy()
268+
}
269+
}
270+
}
271+
}
272+
// `stop` is false iff left or right has finished iteration in step 1.
273+
// if we get into step 2, `stop` cannot be false.
274+
if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) {
275+
if (streamedElement == null && bufferedElement != null) {
276+
// streamedElement == null but bufferedElement != null
277+
if (joinType == FullOuter) {
278+
continueStreamed = false
279+
return true
157280
}
158-
if (rightMatches.size > 0) {
159-
rightPosition = 0
160-
matchKey = leftKey
281+
} else if (streamedElement != null && bufferedElement == null) {
282+
// bufferedElement == null but streamedElement != null
283+
if (joinType != Inner) {
284+
continueStreamed = true
285+
return true
161286
}
162287
}
163288
}
164-
rightMatches != null && rightMatches.size > 0
289+
bufferedMatches != null && bufferedMatches.size > 0
165290
}
166291
}
167292
}

0 commit comments

Comments
 (0)