Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address comments
  • Loading branch information
Andrew Or committed Aug 10, 2015
commit b4d3633b256de6d981ef7fd2e62afa5490323682
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ case class TungstenAggregate(
// We're not using the underlying map, so we can just free it here
aggregationIterator.free()
if (groupingExpressions.isEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator[UnsafeRow]()
}
} else {
Expand All @@ -104,10 +104,9 @@ case class TungstenAggregate(

// Note: we need to set up the iterator in each partition before computing the
// parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
val parentPartition = child.execute()
val resultRdd = {
new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
parentPartition, preparePartition, executePartition, preservesPartitioning = true)
child.execute(), preparePartition, executePartition, preservesPartitioning = true)
}
resultRdd.asInstanceOf[RDD[InternalRow]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just return resultRdd? Seems we do not need to cast?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually result RDD is of type RDD[UnsafeRow]. Since RDDs are not covariant I think we do need the cast.

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class TungstenAggregationIterator(
extends Iterator[UnsafeRow] with Logging {

// The parent partition iterator, to be initialized later in `start`
private[this] var inputIter: Iterator[InternalRow] = Iterator[InternalRow]()
private[this] var inputIter: Iterator[InternalRow] = null

///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
Expand Down Expand Up @@ -334,7 +334,7 @@ class TungstenAggregationIterator(
// This is the hash map used for hash-based aggregation. It is backed by an
// UnsafeFixedWidthAggregationMap and it is used to store
// all groups and their corresponding aggregation buffers for hash-based aggregation.
private[aggregate] val hashMap = new UnsafeFixedWidthAggregationMap(
private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
initialAggregationBuffer,
StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
Expand All @@ -345,11 +345,15 @@ class TungstenAggregationIterator(
false // disable tracking of performance metrics
)

// Exposed for testing: returns the `hashMap` instance used for hash-based aggregation.
// Visible only within the `aggregate` package so that tests can inspect the map.
private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap

// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
val groupingKey = groupProjection.apply(newInput)
Expand All @@ -368,6 +372,7 @@ class TungstenAggregationIterator(
// that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
// been processed.
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
var i = 0
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
Expand Down Expand Up @@ -407,6 +412,7 @@ class TungstenAggregationIterator(
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
Expand All @@ -426,8 +432,9 @@ class TungstenAggregationIterator(
case _ => false
}

// Note: we spill the sorter's contents immediately after creating it. Therefore, we must
// insert something into the sorter here to ensure that we acquire at least a page of memory.
// Note: Since we spill the sorter's contents immediately after creating it, we must insert
// something into the sorter here to ensure that we acquire at least a page of memory.
// This is done through `externalSorter.insertKV`, which will trigger the page allocation.
// Otherwise, children operators may steal the window of opportunity and starve our sorter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we explicitly say that externalSorter.insertKV(firstKey, buffer) will trigger the page allocation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (needsProcess) {
Expand Down Expand Up @@ -684,7 +691,7 @@ class TungstenAggregationIterator(
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
assert(groupingExpressions.isEmpty)
assert(!inputIter.hasNext)
assert(inputIter == null)
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkCont
}
iter = new TungstenAggregationIterator(
Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None)
val numPages = iter.hashMap.getNumDataPages
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)
} finally {
// Clean up
Expand Down