Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix a bug in GroupedIterator and create unit test for it
  • Loading branch information
cloud-fan committed Oct 28, 2015
commit a8cc6b51f40898e5d29f40def64079a5b6530574
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object GroupedIterator {
keyExpressions: Seq[Expression],
inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
if (input.hasNext) {
new GroupedIterator(input, keyExpressions, inputSchema)
new GroupedIterator(input.buffered, keyExpressions, inputSchema)
} else {
Iterator.empty
}
Expand Down Expand Up @@ -64,7 +64,7 @@ object GroupedIterator {
* @param inputSchema The schema of the rows in the `input` iterator.
*/
class GroupedIterator private(
input: Iterator[InternalRow],
input: BufferedIterator[InternalRow],
groupingExpressions: Seq[Expression],
inputSchema: Seq[Attribute])
extends Iterator[(InternalRow, Iterator[InternalRow])] {
Expand All @@ -83,11 +83,12 @@ class GroupedIterator private(

/** Holds a copy of an input row that is in the current group. */
var currentGroup = currentRow.copy()
var currentIterator: Iterator[InternalRow] = null

assert(keyOrdering.compare(currentGroup, currentRow) == 0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we only use keyOrdering to do equality check, why not just use ==? The currentGroup and currentRow are from the same input, they must be both unsafe or safe, and == for UnsafeRow is faster than keyOrdering.compare.

cc @marmbrus

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the whole row, not just the key. This allows us to do the equality check on the key columns only (which might short circuit) instead of doing a full projection on each row to extract the key columns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry I missed it

var currentIterator = createGroupValuesIterator()

// Return true if we already have the next iterator or fetching a new iterator is successful.
def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
def hasNext: Boolean = currentIterator.ne(null) || fetchNextGroupIterator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these are the same, and I prefer the idiomatic version.


def next(): (InternalRow, Iterator[InternalRow]) = {
assert(hasNext) // Ensure we have fetched the next iterator.
Expand All @@ -96,46 +97,64 @@ class GroupedIterator private(
ret
}

def fetchNextGroupIterator(): Boolean = {
if (currentRow != null || input.hasNext) {
val inputIterator = new Iterator[InternalRow] {
// Return true if we have a row and it is in the current group, or if fetching a new row is
// successful.
def hasNext = {
(currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
fetchNextRowInGroup()
}
private def fetchNextGroupIterator(): Boolean = {
assert(currentIterator eq null)

if (currentRow.eq(null) && input.hasNext) {
currentRow = input.next()
}

if (currentRow eq null) {
// There is no data left, return false.
false
} else {
// Skip to next group.
while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
currentRow = input.next()
}

if (keyOrdering.compare(currentGroup, currentRow) == 0) {
// There are no more groups, return false.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "there" or maybe more clearly "we are no longer in the current group, return false."

false
} else {
// Now the `currentRow` is the first row of next group.
currentGroup = currentRow.copy()
currentIterator = createGroupValuesIterator()
true
}
}
}

private def createGroupValuesIterator(): Iterator[InternalRow] = {
new Iterator[InternalRow] {
def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()

def next(): InternalRow = {
assert(hasNext)
val res = currentRow
currentRow = null
res
}

def fetchNextRowInGroup(): Boolean = {
if (currentRow != null || input.hasNext) {
private def fetchNextRowInGroup(): Boolean = {
assert(currentRow eq null)

if (input.hasNext) {
// The inner iterator should NOT consume the input into next group, here we use `head` to
// peek the next input, to see if we should continue to process it.
if (keyOrdering.compare(currentGroup, input.head) == 0) {
// Next input is in the current group. Continue the inner iterator.
currentRow = input.next()
if (keyOrdering.compare(currentGroup, currentRow) == 0) {
// The row is in the current group. Continue the inner iterator.
true
} else {
// We got a row, but it's not in the right group. End this inner iterator and prepare
// for the next group.
currentIterator = null
currentGroup = currentRow.copy()
false
}
true
} else {
// There is no more input so we are done.
// Next input is not in the right group. End this inner iterator.
false
}
}

def next(): InternalRow = {
assert(hasNext) // Ensure we have fetched the next row.
val res = currentRow
currentRow = null
res
} else {
// There is no more data, return false.
false
}
}
currentIterator = inputIterator
true
} else {
false
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package org.apache.spark.sql.execution
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add the apache header


import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

/**
 * Tests for `GroupedIterator`: grouping on a single key column, grouping on two key columns,
 * and walking the outer iterator without ever reading the per-group value iterators.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))

    // Group on the first (integer) column only.
    val internalRows = input.iterator.map(encoder.toRow)
    val grouped = GroupedIterator(internalRows, Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map { group =>
      val key = group._1
      val data = group._2
      assert(key.numFields == 1)
      (key.getInt(0), data.map(encoder.fromRow).toSeq)
    }.toSeq

    // Rows 0 and 1 share key 1; row 2 is alone under key 2.
    val expected = List(
      1 -> Seq(input(0), input(1)),
      2 -> Seq(input(2)))

    assert(result == expected)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    // The grouping key is the pair (i, l); the string column is payload only.
    val internalRows = input.iterator.map(encoder.toRow)
    val groupingKeys = Seq('i.int.at(0), 'l.long.at(1))
    val grouped = GroupedIterator(internalRows, groupingKeys, schema.toAttributes)

    val result = grouped.map { group =>
      val key = group._1
      val data = group._2
      assert(key.numFields == 2)
      (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    val expected = List(
      (1, 2L, Seq(input(0), input(1))),
      (1, 3L, Seq(input(2))),
      (2, 1L, Seq(input(3))),
      (3, 2L, Seq(input(4))))

    assert(result == expected)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val internalRows = input.iterator.map(encoder.toRow)
    val grouped = GroupedIterator(internalRows, Seq('i.int.at(0)), schema.toAttributes)

    // Counting the groups advances only the outer iterator; none of the per-group value
    // iterators is read, yet both groups (keys 1 and 2) must still be reported.
    assert(grouped.size == 2)
  }
}