[SPARK-11370][SQL] fix a bug in GroupedIterator and create unit test for it #9330
Changes from 1 commit
**`GroupedIterator.scala`**

```diff
@@ -27,7 +27,7 @@ object GroupedIterator {
       keyExpressions: Seq[Expression],
       inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
     if (input.hasNext) {
-      new GroupedIterator(input, keyExpressions, inputSchema)
+      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
     } else {
       Iterator.empty
     }
```
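The fix turns on `Iterator.buffered`: a `BufferedIterator` adds `head`, which peeks at the next element without consuming it. A minimal standalone sketch of that behavior (plain Scala, no Spark types; the values are only illustrative):

```scala
object BufferedPeekDemo extends App {
  val it: BufferedIterator[Int] = Iterator(1, 1, 2).buffered

  // `head` peeks without advancing the iterator ...
  assert(it.head == 1)
  assert(it.head == 1)

  // ... while `next()` consumes an element.
  assert(it.next() == 1)
  assert(it.head == 1) // the second 1 is still available
  assert(it.next() == 1)
  assert(it.head == 2) // peeking at the next group's first element loses nothing
  println("head peeks, next() consumes")
}
```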
```diff
@@ -64,7 +64,7 @@ object GroupedIterator {
   * @param inputSchema The schema of the rows in the `input` iterator.
   */
 class GroupedIterator private(
-    input: Iterator[InternalRow],
+    input: BufferedIterator[InternalRow],
     groupingExpressions: Seq[Expression],
     inputSchema: Seq[Attribute])
   extends Iterator[(InternalRow, Iterator[InternalRow])] {
```
```diff
@@ -83,11 +83,12 @@ class GroupedIterator private(
   /** Holds a copy of an input row that is in the current group. */
   var currentGroup = currentRow.copy()
-  var currentIterator: Iterator[InternalRow] = null
   assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+  var currentIterator = createGroupValuesIterator()
 
   // Return true if we already have the next iterator or fetching a new iterator is successful.
-  def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
+  def hasNext: Boolean = currentIterator.ne(null) || fetchNextGroupIterator
 
   def next(): (InternalRow, Iterator[InternalRow]) = {
     assert(hasNext) // Ensure we have fetched the next iterator.
```

> **Contributor**, on the `hasNext` change: I think these are the same, and I prefer the idiomatic version.
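On the `!= null` versus `.ne(null)` point above: for reference types the two checks return the same answer, which is presumably why the reviewer prefers the shorter `!= null` form. A quick sketch (values are only illustrative):

```scala
object NullCheckDemo extends App {
  val present: String = "spark"
  val absent: String = null

  // For an AnyRef, `x != null` and `x.ne(null)` agree; `!= null` is the
  // more commonly seen spelling.
  assert((present != null) == present.ne(null))
  assert((absent != null) == absent.ne(null))
  println("both null checks agree")
}
```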
```diff
@@ -96,46 +97,64 @@
     ret
   }
 
-  def fetchNextGroupIterator(): Boolean = {
-    if (currentRow != null || input.hasNext) {
-      val inputIterator = new Iterator[InternalRow] {
-        // Return true if we have a row and it is in the current group, or if fetching a new row is
-        // successful.
-        def hasNext = {
-          (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
-            fetchNextRowInGroup()
-        }
-
-        def fetchNextRowInGroup(): Boolean = {
-          if (currentRow != null || input.hasNext) {
-            currentRow = input.next()
-            if (keyOrdering.compare(currentGroup, currentRow) == 0) {
-              // The row is in the current group. Continue the inner iterator.
-              true
-            } else {
-              // We got a row, but its not in the right group. End this inner iterator and prepare
-              // for the next group.
-              currentIterator = null
-              currentGroup = currentRow.copy()
-              false
-            }
-          } else {
-            // There is no more input so we are done.
-            false
-          }
-        }
-
-        def next(): InternalRow = {
-          assert(hasNext) // Ensure we have fetched the next row.
-          val res = currentRow
-          currentRow = null
-          res
-        }
-      }
-      currentIterator = inputIterator
-      true
-    } else {
-      false
-    }
-  }
+  private def fetchNextGroupIterator(): Boolean = {
+    assert(currentIterator eq null)
+
+    if (currentRow.eq(null) && input.hasNext) {
+      currentRow = input.next()
+    }
+
+    if (currentRow eq null) {
+      // These is no data left, return false.
+      false
+    } else {
+      // Skip to next group.
+      while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+        currentRow = input.next()
+      }
+
+      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+        // These is no more group. return false.
+        false
+      } else {
+        // Now the `currentRow` is the first row of next group.
+        currentGroup = currentRow.copy()
+        currentIterator = createGroupValuesIterator()
+        true
+      }
+    }
+  }
+
+  private def createGroupValuesIterator(): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
+
+      def next(): InternalRow = {
+        assert(hasNext)
+        val res = currentRow
+        currentRow = null
+        res
+      }
+
+      private def fetchNextRowInGroup(): Boolean = {
+        assert(currentRow eq null)
+
+        if (input.hasNext) {
+          // The inner iterator should NOT consume the input into next group, here we use `head` to
+          // peek the next input, to see if we should continue to process it.
+          if (keyOrdering.compare(currentGroup, input.head) == 0) {
+            // Next input is in the current group. Continue the inner iterator.
+            currentRow = input.next()
+            true
+          } else {
+            // Next input is not in the right group. End this inner iterator.
+            false
+          }
+        } else {
+          // There is no more data, return false.
+          false
+        }
+      }
+    }
+  }
 }
```

> **Contributor**, on `// These is no more group. return false.`: nit: "there" or maybe more clearly "we are no longer in the current group, return false."
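To see the shape of the new implementation without Spark's row machinery, here is a rough standalone sketch of the same grouping-by-peeking pattern over a sorted iterator of key/value pairs. The names (`groupSorted`, `skipCurrentGroup`) are illustrative and not part of the Spark code, and the sketch glosses over `InternalRow` copying and the generated key ordering:

```scala
object GroupSortedSketch extends App {
  // Groups consecutive equal keys from a sorted iterator, peeking with
  // BufferedIterator.head so the per-group iterator never consumes a pair
  // that belongs to the following group.
  def groupSorted[K, V](input: Iterator[(K, V)]): Iterator[(K, Iterator[V])] = {
    val buffered = input.buffered
    new Iterator[(K, Iterator[V])] {
      private var currentKey: Option[K] = None

      // As in GroupedIterator, moving on discards any unread rows of the
      // group that was handed out last.
      private def skipCurrentGroup(): Unit = currentKey.foreach { k =>
        while (buffered.hasNext && buffered.head._1 == k) buffered.next()
      }

      def hasNext: Boolean = { skipCurrentGroup(); buffered.hasNext }

      def next(): (K, Iterator[V]) = {
        skipCurrentGroup()
        val key = buffered.head._1
        currentKey = Some(key)
        val values = new Iterator[V] {
          // `head` only peeks, so the first pair of the next group stays put.
          def hasNext: Boolean = buffered.hasNext && buffered.head._1 == key
          def next(): V = buffered.next()._2
        }
        key -> values
      }
    }
  }

  // Counting groups without reading any values still gives the right answer,
  // which is what the "do nothing to the value iterator" test below checks
  // for the real implementation.
  val groups = groupSorted(Iterator(1 -> "a", 1 -> "b", 2 -> "c"))
  assert(groups.length == 2)
  println("2 groups, values untouched")
}
```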
**`GroupedIteratorSuite.scala`** (new file, `@@ -0,0 +1,65 @@`)

```scala
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
}
```

> **Contributor**, on the `package` line: Need to add the apache header.
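The three tests above do not exercise empty input, which `GroupedIterator.apply` handles by returning `Iterator.empty` (first hunk). A hypothetical extra test in the same style; the class name and the test itself are only a sketch, not part of this PR:

```scala
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Hypothetical companion suite, only to illustrate the empty-input branch.
class GroupedIteratorExtraSuite extends SparkFunSuite {

  test("empty input yields no groups") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    // Same construction as above, but with no input rows at all.
    val grouped = GroupedIterator(Seq.empty[Row].iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.isEmpty)
  }
}
```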
Review discussion on the `keyOrdering` comparisons:

> **Comment:** Looks like we only use `keyOrdering` to do equality check, why not just use `==`? The `currentGroup` and `currentRow` are from the same input, they must be both unsafe or safe, and `==` for `UnsafeRow` is faster than `keyOrdering.compare`. cc @marmbrus
>
> **Reply:** This is the whole row, not just the key. This allows us to do the equality check on the key columns only (which might short circuit) instead of doing a full projection on each row to extract the key columns.
>
> **Reply:** Ah, sorry I missed it.
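A rough illustration of the trade-off in this thread, with a case class standing in for `InternalRow` (all names are illustrative): checking equality on the grouping columns of the full rows can short-circuit on the first differing column, while comparing with `==` on projected keys means materializing a key object for every row first.

```scala
object KeyCompareSketch extends App {
  // Stand-in for a full row whose grouping columns are (i, l) plus a payload s.
  case class FullRow(i: Int, l: Long, s: String)

  // keyOrdering-style check: look only at the grouping columns of the full rows.
  def sameGroup(a: FullRow, b: FullRow): Boolean =
    a.i == b.i && a.l == b.l // short-circuits on the first differing column

  // `==` on projected keys: extract a key per row, then compare the keys.
  def sameGroupViaProjection(a: FullRow, b: FullRow): Boolean =
    (a.i, a.l) == (b.i, b.l)

  val r1 = FullRow(1, 2L, "a")
  val r2 = FullRow(1, 3L, "c")
  assert(sameGroup(r1, r2) == sameGroupViaProjection(r1, r2))
  println("both checks agree; the first avoids building a key per row")
}
```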