Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Change number-of-rows fallback behavior to restrict fast map capacity
  • Loading branch information
c21 committed Apr 22, 2021
commit 67d4cd76c4de27b69acb8edc1ea37972a2de67aa
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,6 @@ case class HashAggregateExec(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct

// Test-only flag: true when a number-of-rows-based fallback is configured via
// `testFallbackStartsAt` and this operator runs a Final or Complete mode aggregate.
//
// In that scenario the same key may exist in both the fast hash map and the regular hash
// map, so the aggregation buffers from both maps need to be merged together to avoid a
// correctness issue.
//
// This only happens in unit tests with number-of-rows-based fallback. With the size-based
// fallback used in production, the same key should not end up in both maps.
private val isTestFinalAggregateWithFallback: Boolean =
  testFallbackStartsAt.nonEmpty && modes.exists(mode => mode == Final || mode == Complete)

override def usedInputs: AttributeSet = inputSet

override def supportCodegen: Boolean = {
Expand Down Expand Up @@ -547,34 +537,6 @@ case class HashAggregateExec(
}
}

/**
 * Called by the generated Java class to merge the contents of the fast hash map into the
 * regular hash map. This is used for testing final aggregate only.
 *
 * @param fastHashMapRowIter iterator over the (key, aggregation buffer) rows of the fast hash
 *                           map; it is closed before this method returns, even on failure
 * @param regularHashMap regular hash map whose buffers receive the merged values
 */
def mergeFastHashMapForTest(
    fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow],
    regularHashMap: UnsafeFixedWidthAggregationMap): Unit = {
  try {
    // Create a MutableProjection to merge the buffers of the same key together
    val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
    val mergeProjection = MutableProjection.create(
      mergeExpr,
      aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
    val joinedRow = new JoinedRow()

    while (fastHashMapRowIter.next()) {
      val key = fastHashMapRowIter.getKey
      val fastMapBuffer = fastHashMapRowIter.getValue
      // NOTE(review): presumably this lookup can yield null when the regular map cannot
      // allocate a buffer for a new key - confirm against UnsafeFixedWidthAggregationMap.
      val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key)

      // Merge the aggregation buffer of the fast hash map into the buffer with the same key
      // in the regular map; the projection writes its result into `regularMapBuffer`.
      mergeProjection.target(regularMapBuffer)
      mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer))
    }
  } finally {
    // Release the iterator even if building the projection or merging a row throws.
    fastHashMapRowIter.close()
  }
}

/**
* Generate the code for output.
* @return function name for the result code.
Expand Down Expand Up @@ -721,7 +683,18 @@ case class HashAggregateExec(
} else if (sqlContext.conf.enableVectorizedHashMap) {
logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
}
val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit
// Maximum capacity of the fast hash map, expressed as a number of bits.
val bitMaxCapacity = testFallbackStartsAt match {
  case Some((fastMapCounter, _)) =>
    // In testing, with the fallback counter of the fast hash map (`fastMapCounter`), set the
    // max bit of the map to be no more than log2(`fastMapCounter`). This helps control the
    // number of keys in the map to mimic fallback.
    if (fastMapCounter <= 1) {
      0
    } else {
      // Exact floor(log2(n)) for a positive Int. Integer bit arithmetic avoids the rounding
      // error of `log10(n) / log10(2)`, which can land just below the true value when n is
      // an exact power of two and then floor to one bit too few.
      // NOTE(review): assumes `fastMapCounter` is an Int - confirm.
      31 - java.lang.Integer.numberOfLeadingZeros(fastMapCounter)
    }
  case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
}

val thisPlan = ctx.addReferenceObj("plan", this)

Expand Down Expand Up @@ -797,13 +770,8 @@ case class HashAggregateExec(
val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
val finishHashMap = if (isFastHashMapEnabled) {
val finishFastHashMap = if (isTestFinalAggregateWithFallback) {
s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);"
} else {
s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"
}
s"""
|$finishFastHashMap
|$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
|$finishRegularHashMap
""".stripMargin
} else {
Expand All @@ -826,7 +794,7 @@ case class HashAggregateExec(
val limitNotReachedCondition = limitNotReachedCond
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds limit-based early termination for the first-level (fast) hash map as well. It is needed to fix the test failure `SQLMetricsSuite.SPARK-25497: LIMIT within whole stage codegen should not consume all the inputs` seen in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true, and it is good to have anyway.


def outputFromFastHashMap: String = {
if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
outputFromVectorizedMap
} else {
Expand Down Expand Up @@ -931,13 +899,11 @@ case class HashAggregateExec(
}
}

val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "true", "", "")
// Builds three generated-code snippets for the regular hash map's test fallback:
//   - checkFallbackForBytesToBytesMap: a boolean condition,
//   - resetCounter: a statement that sets the counter back to 0,
//   - incCounter: a statement that increments the counter by 1.
// When `testFallbackStartsAt` is defined, a mutable int state named "fallbackCounter" is
// registered and the condition is `counter < regularMapCounter` (the second element of the
// pair). Otherwise the condition is the constant "true" and both statements are empty.
val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
case Some((_, regularMapCounter)) =>
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
case _ => ("true", "", "")
}

val oomeClassName = classOf[SparkOutOfMemoryError].getName
Expand Down Expand Up @@ -977,12 +943,10 @@ case class HashAggregateExec(
// If fast hash map is on, we first generate code to probe and update the fast hash map.
// If the probe is successful the corresponding fast row buffer will hold the mutable row.
s"""
|if ($checkFallbackForGeneratedHashMap) {
| ${fastRowKeys.map(_.code).mkString("\n")}
| if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $fastRowBuffer = $fastHashMapTerm.findOrInsert(
| ${fastRowKeys.map(_.value).mkString(", ")});
| }
|${fastRowKeys.map(_.code).mkString("\n")}
|if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $fastRowBuffer = $fastHashMapTerm.findOrInsert(
| ${fastRowKeys.map(_.value).mkString(", ")});
|}
|// Cannot find the key in fast hash map, try regular hash map.
|if ($fastRowBuffer == null) {
Expand Down