Skip to content

Commit 4dc723d

Browse files
committed
[SPARK-10100][SQL] Eliminate hash table lookup if there is no grouping key in aggregation.
This improves performance by ~ 20% in one of my local test.
1 parent b762f99 commit 4dc723d

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ case class TungstenAggregate(
104104
} else {
105105
// This is a grouped aggregate and the input iterator is empty,
106106
// so return an empty iterator.
107-
Iterator[UnsafeRow]()
107+
Iterator.empty
108108
}
109109
} else {
110110
aggregationIterator.start(parentIterator)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,18 +357,30 @@ class TungstenAggregationIterator(
357357
// sort-based aggregation (by calling switchToSortBasedAggregation).
358358
private def processInputs(): Unit = {
359359
assert(inputIter != null, "attempted to process input when iterator was null")
360-
while (!sortBased && inputIter.hasNext) {
361-
val newInput = inputIter.next()
362-
numInputRows += 1
363-
val groupingKey = groupProjection.apply(newInput)
360+
if (groupingExpressions.isEmpty) {
361+
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
362+
// Note that it would be better to eliminate the hash map entirely in the future.
363+
val groupingKey = groupProjection.apply(null)
364364
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
365-
if (buffer == null) {
366-
// buffer == null means that we could not allocate more memory.
367-
// Now, we need to spill the map and switch to sort-based aggregation.
368-
switchToSortBasedAggregation(groupingKey, newInput)
369-
} else {
365+
while (inputIter.hasNext) {
366+
val newInput = inputIter.next()
367+
numInputRows += 1
370368
processRow(buffer, newInput)
371369
}
370+
} else {
371+
while (!sortBased && inputIter.hasNext) {
372+
val newInput = inputIter.next()
373+
numInputRows += 1
374+
val groupingKey = groupProjection.apply(newInput)
375+
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
376+
if (buffer == null) {
377+
// buffer == null means that we could not allocate more memory.
378+
// Now, we need to spill the map and switch to sort-based aggregation.
379+
switchToSortBasedAggregation(groupingKey, newInput)
380+
} else {
381+
processRow(buffer, newInput)
382+
}
383+
}
372384
}
373385
}
374386

0 commit comments

Comments
 (0)