-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35141][SQL] Support two level of hash maps for final hash aggregation #32242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
539720d
c9d09be
917e7bb
67d4cd7
965a35c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._ | |
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.TaskContext | ||
| import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} | ||
| import org.apache.spark.memory.SparkOutOfMemoryError | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
|
|
@@ -128,6 +128,16 @@ case class HashAggregateExec( | |
| // all the mode of aggregate expressions | ||
| private val modes = aggregateExpressions.map(_.mode).distinct | ||
|
|
||
| // This is for testing final aggregate with number-of-rows-based fall back as specified in | ||
| // `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and | ||
| // regular hash map. So the aggregation buffers from both maps need to be merged together | ||
| // to avoid correctness issue. | ||
| // | ||
| // This scenario only happens in unit test with number-of-rows-based fall back. | ||
| // There should not be same keys in both maps with size-based fall back in production. | ||
| private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined && | ||
|
||
| (modes.contains(Final) || modes.contains(Complete)) | ||
|
|
||
| override def usedInputs: AttributeSet = inputSet | ||
|
|
||
| override def supportCodegen: Boolean = { | ||
|
|
@@ -435,8 +445,8 @@ case class HashAggregateExec( | |
| ) | ||
| } | ||
|
|
||
| def getTaskMemoryManager(): TaskMemoryManager = { | ||
| TaskContext.get().taskMemoryManager() | ||
| def getTaskContext(): TaskContext = { | ||
| TaskContext.get() | ||
| } | ||
|
|
||
| def getEmptyAggregationBuffer(): InternalRow = { | ||
|
|
@@ -537,6 +547,34 @@ case class HashAggregateExec( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Called by generated Java class to finish merge the fast hash map into regular map. | ||
| * This is used for testing final aggregate only. | ||
| */ | ||
| def mergeFastHashMapForTest( | ||
| fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow], | ||
| regularHashMap: UnsafeFixedWidthAggregationMap): Unit = { | ||
|
|
||
| // Create a MutableProjection to merge the buffers of same key together | ||
| val mergeExpr = declFunctions.flatMap(_.mergeExpressions) | ||
| val mergeProjection = MutableProjection.create( | ||
| mergeExpr, | ||
| aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes)) | ||
| val joinedRow = new JoinedRow() | ||
|
|
||
| while (fastHashMapRowIter.next()) { | ||
| val key = fastHashMapRowIter.getKey | ||
| val fastMapBuffer = fastHashMapRowIter.getValue | ||
| val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key) | ||
|
|
||
| // Merge the aggregation buffer of fast hash map, into the buffer with same key of | ||
| // regular map | ||
| mergeProjection.target(regularMapBuffer) | ||
| mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer)) | ||
| } | ||
| fastHashMapRowIter.close() | ||
| } | ||
|
|
||
| /** | ||
| * Generate the code for output. | ||
| * @return function name for the result code. | ||
|
|
@@ -647,7 +685,7 @@ case class HashAggregateExec( | |
| (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || | ||
| f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] || | ||
| f.dataType.isInstanceOf[CalendarIntervalType]) && | ||
| bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) | ||
| bufferSchema.nonEmpty | ||
|
|
||
| // For vectorized hash map, We do not support byte array based decimal type for aggregate values | ||
| // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place | ||
|
|
@@ -663,7 +701,7 @@ case class HashAggregateExec( | |
|
|
||
| private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { | ||
| if (!checkIfFastHashMapSupported(ctx)) { | ||
| if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Last question: can we search the commit history and figure out why we didn't enable the fast hash map in the final aggregate? It seems we did it on purpose.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @cloud-fan - I was wondering about this in the first place before making this PR as well. The decision to only support partial aggregate was made when the first-level hash map was introduced (#12345 and #14176), and it was never changed afterwards. I checked with @sameeragarwal before making this PR. He told me there is no fundamental reason not to support final aggregate. Just for documentation, I asked him why we don't support nested types (array/map/struct) as key types for the fast hash map. He told me the reason was that the size of keys might be too large for long array/map/struct values, so the fast hash map may not fit in cache and would lose its benefit. |
||
| if (!Utils.isTesting) { | ||
| logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" | ||
| + " current version of codegened fast hashmap does not support this aggregate.") | ||
| } | ||
|
|
@@ -687,40 +725,59 @@ case class HashAggregateExec( | |
|
|
||
| val thisPlan = ctx.addReferenceObj("plan", this) | ||
|
|
||
| // Create a name for the iterator from the fast hash map, and the code to create fast hash map. | ||
| val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) { | ||
| // Generates the fast hash map class and creates the fast hash map term. | ||
| val fastHashMapClassName = ctx.freshName("FastHashMap") | ||
| if (isVectorizedHashMapEnabled) { | ||
| val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, | ||
| fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() | ||
| ctx.addInnerClass(generatedMap) | ||
|
|
||
| // Inline mutable state since not many aggregation operations in a task | ||
| fastHashMapTerm = ctx.addMutableState( | ||
| fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) | ||
| val iter = ctx.addMutableState( | ||
| "java.util.Iterator<InternalRow>", | ||
| "vectorizedFastHashMapIter", | ||
| forceInline = true) | ||
| val create = s"$fastHashMapTerm = new $fastHashMapClassName();" | ||
| (iter, create) | ||
| } else { | ||
| val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, | ||
| fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() | ||
| ctx.addInnerClass(generatedMap) | ||
|
|
||
| // Inline mutable state since not many aggregation operations in a task | ||
| fastHashMapTerm = ctx.addMutableState( | ||
| fastHashMapClassName, "fastHashMap", forceInline = true) | ||
| val iter = ctx.addMutableState( | ||
| "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>", | ||
| "fastHashMapIter", forceInline = true) | ||
| val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + | ||
| s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());" | ||
| (iter, create) | ||
| } | ||
| } else ("", "") | ||
| // Create a name for the iterator from the fast hash map, the code to create | ||
| // and add hook to close fast hash map. | ||
| val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) = | ||
| if (isFastHashMapEnabled) { | ||
| // Generates the fast hash map class and creates the fast hash map term. | ||
| val fastHashMapClassName = ctx.freshName("FastHashMap") | ||
| val (iter, create) = if (isVectorizedHashMapEnabled) { | ||
|
||
| val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, | ||
| fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() | ||
| ctx.addInnerClass(generatedMap) | ||
|
|
||
| // Inline mutable state since not many aggregation operations in a task | ||
| fastHashMapTerm = ctx.addMutableState( | ||
| fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) | ||
| val iter = ctx.addMutableState( | ||
| "java.util.Iterator<InternalRow>", | ||
| "vectorizedFastHashMapIter", | ||
| forceInline = true) | ||
| val create = s"$fastHashMapTerm = new $fastHashMapClassName();" | ||
| (iter, create) | ||
| } else { | ||
| val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, | ||
| fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() | ||
| ctx.addInnerClass(generatedMap) | ||
|
|
||
| // Inline mutable state since not many aggregation operations in a task | ||
| fastHashMapTerm = ctx.addMutableState( | ||
| fastHashMapClassName, "fastHashMap", forceInline = true) | ||
| val iter = ctx.addMutableState( | ||
| "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>", | ||
| "fastHashMapIter", forceInline = true) | ||
| val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + | ||
| s"$thisPlan.getTaskContext().taskMemoryManager(), " + | ||
|
||
| s"$thisPlan.getEmptyAggregationBuffer());" | ||
| (iter, create) | ||
| } | ||
|
|
||
| // Generates the code to register a cleanup task with TaskContext to ensure that memory | ||
| // is guaranteed to be freed at the end of the task. This is necessary to avoid memory | ||
| // leaks in when the downstream operator does not fully consume the aggregation map's | ||
| // output (e.g. aggregate followed by limit). | ||
| val hookToCloseFastHashMap = | ||
| s""" | ||
| |$thisPlan.getTaskContext().addTaskCompletionListener( | ||
| | new org.apache.spark.util.TaskCompletionListener() { | ||
| | @Override | ||
| | public void onTaskCompletion(org.apache.spark.TaskContext context) { | ||
| | $fastHashMapTerm.close(); | ||
| | } | ||
| |}); | ||
| """.stripMargin | ||
| (iter, create, hookToCloseFastHashMap) | ||
|
||
| } else ("", "", "") | ||
|
|
||
| // Create a name for the iterator from the regular hash map. | ||
| // Inline mutable state since not many aggregation operations in a task | ||
|
|
@@ -740,8 +797,13 @@ case class HashAggregateExec( | |
| val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + | ||
| s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);" | ||
| val finishHashMap = if (isFastHashMapEnabled) { | ||
| val finishFastHashMap = if (isTestFinalAggregateWithFallback) { | ||
| s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);" | ||
| } else { | ||
| s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();" | ||
| } | ||
| s""" | ||
| |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); | ||
| |$finishFastHashMap | ||
| |$finishRegularHashMap | ||
| """.stripMargin | ||
| } else { | ||
|
|
@@ -761,8 +823,10 @@ case class HashAggregateExec( | |
| val bufferTerm = ctx.freshName("aggBuffer") | ||
| val outputFunc = generateResultFunction(ctx) | ||
|
|
||
| val limitNotReachedCondition = limitNotReachedCond | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adding the limit early-termination for the first-level map as well. This is needed to fix the test failure. |
||
|
|
||
| def outputFromFastHashMap: String = { | ||
| if (isFastHashMapEnabled) { | ||
| if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) { | ||
| if (isVectorizedHashMapEnabled) { | ||
| outputFromVectorizedMap | ||
| } else { | ||
|
|
@@ -773,7 +837,7 @@ case class HashAggregateExec( | |
|
|
||
| def outputFromRowBasedMap: String = { | ||
| s""" | ||
| |while ($iterTermForFastHashMap.next()) { | ||
| |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) { | ||
| | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); | ||
| | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); | ||
| | $outputFunc($keyTerm, $bufferTerm); | ||
|
|
@@ -798,7 +862,7 @@ case class HashAggregateExec( | |
| BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) | ||
| }) | ||
| s""" | ||
| |while ($iterTermForFastHashMap.hasNext()) { | ||
| |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) { | ||
| | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); | ||
| | ${generateKeyRow.code} | ||
| | ${generateBufferRow.code} | ||
|
|
@@ -813,7 +877,7 @@ case class HashAggregateExec( | |
|
|
||
| def outputFromRegularHashMap: String = { | ||
| s""" | ||
| |while ($limitNotReachedCond $iterTerm.next()) { | ||
| |while ($limitNotReachedCondition $iterTerm.next()) { | ||
| | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | ||
| | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | ||
| | $outputFunc($keyTerm, $bufferTerm); | ||
|
|
@@ -832,6 +896,7 @@ case class HashAggregateExec( | |
| |if (!$initAgg) { | ||
| | $initAgg = true; | ||
| | $createFastHashMap | ||
| | $addHookToCloseFastHashMap | ||
| | $hashMapTerm = $thisPlan.createHashMap(); | ||
| | long $beforeAgg = System.nanoTime(); | ||
| | $doAggFuncName(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is needed as the hash aggregation output order is changed, and it causes
the ExpressionInfoSuite "check outputs of expression examples" test failure in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true .