Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
Examples:
> SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
Alice 0 2 165.0
Bob 0 5 180.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is needed as hash aggregation output order is changed, and it causes ExpressionInfoSuite.check outputs of expression examples test failure in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true .

Alice 1 2 165.0
NULL 3 7 172.5
Bob 0 5 180.0
Bob 1 5 180.0
NULL 2 2 165.0
NULL 2 5 180.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._
import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.memory.SparkOutOfMemoryError
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -128,6 +128,16 @@ case class HashAggregateExec(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct

// This is for testing final aggregate with number-of-rows-based fall back as specified in
// `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and
// regular hash map. So the aggregation buffers from both maps need to be merged together
// to avoid correctness issue.
//
// This scenario only happens in unit test with number-of-rows-based fall back.
// There should not be same keys in both maps with size-based fall back in production.
private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of merging the hash maps, shall we fix the number-of-rows-based fallback to make it similar to the size-based fallback?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - I was thinking the same way too. I found it's quite hard to fix the fallback logic. I tried the approach to add a find(key): Boolean method in generated first level map, and to first check if key already exists in first level map. But I found other case like the key can be put into second level map, later added to first level map as well (fallback row counter reset to 0 case).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't touched this part of the code for a while, can you briefly introduce how size-based fallback work?

Copy link
Contributor Author

@c21 c21 Apr 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - sure. This is how number-of-rows-based fallback works.

With an internal config spark.sql.TungstenAggregate.testFallbackStartsAt, we can set (1). when to fallback from first level hash map to second level hash map, and (2). when to fallback from second level hash map to sort.

Suppose spark.sql.TungstenAggregate.testFallbackStartsAt = "2, 3".

Then the generated code per input row (aggregate the row into hash map) looks like:

UnsafeRow agg_buffer = null;

if (counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
if (agg_buffer == null) {
  // generated. code for key in unsafe row format
  ...
  if (counter < 3) {
    // 2nd level hash map
    agg_buffer = regularHashMap.getAggregationBufferFromUnsafeRow(key_in_unsafe_row, ...);
  }
  if (agg_buffer == null) {
    // sort-based fallback
    regularHashMap.destructAndCreateExternalSorter();
    ...
    counter = 0;
  }
}
counter += 1;

Example generated code is Line 187-232 in https://gist.github.com/c21/d0f704c0a33c24ec05387ff4df438bff .

I tried to add a method fastHashMap.find(key): boolean, and change code like this:

...
if (fastHashMap.find(key) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
...

But I later found the case as I mentioned above:

  1. key(a) is inserted into second level hash map (when counter exceeds 1st threshold)
  2. sort-based fallback happens, and counter is reset to 0 (when counter exceeds 2nd threshold)
  3. key(a) is not in first level hash map, and counter does not exceed 1st threshold, the key(a) is inserted into first level hash map as well by mistake.

We can further add code like this:

if ((fastHashMap.find(key) && !regularHashMap.find(key_in_unsafe_row)) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}

But it introduces more ad-hoc change and looks pretty ugly with a lot of code needs to be moved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - sorry, I overlooked your question, you are asking how size-based fallback works.

Size-based fallback works as:

  1. try to insert into 1st level hash map, and fallback to 2nd level hash map when no space in the required memory page (RowBasedKeyValueBatch ) - https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala#L165-L166 .
  2. try to insert into 2nd level hash map, and fallback to sort-based when no space in UnsafeFixedWidthAggregationMap - https://github.com/apache/spark/blob/master/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java#L148-L150 .
  3. the 2nd level hash map will be sorted and spilled and another new 2nd level hash map will be created. The 1st level hash map cannot be spilled.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that the major issue is we use a single counter to control both the fast and regular hash map fallback. My first thought is to add a dedicated counter for the fast hash map fallback, then I realized that the fast hash map has a capacity property. Can we simply set the capacity to testFallbackStartsAt._1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - maybe I am missing something but not sure how these two solutions fix the problem.

  1. dedicated counters for two maps
if (counter1 < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
if (agg_buffer == null) {
  // generated. code for key in unsafe row format
  ...
  if (counter2 < 3) {
    // 2nd level hash map
    agg_buffer = regularHashMap.getAggregationBufferFromUnsafeRow(key_in_unsafe_row, ...);
  }
  if (agg_buffer == null) {
    // sort-based fallback
    regularHashMap.destructAndCreateExternalSorter();
    ...
    counter2 = 0;
  }
}
counter1 += 1;
counter2 += 1;

Counter example:

1. key_a is inserted into 1st level map (counter1 = 0)
2. a couple of keys are inserted into 1st level map (count1 =2)
3. key_a is inserted into 2nd level map (count1 = 2, count2 = 2)
  1. set 1st level map bitMaxCapacity to be log2(testFallbackStartsAt._1).
if (counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
if (agg_buffer == null) {
  // generated. code for key in unsafe row format
  ...
  if (counter < 3) {
    // 2nd level hash map
    agg_buffer = regularHashMap.getAggregationBufferFromUnsafeRow(key_in_unsafe_row, ...);
  }
  if (agg_buffer == null) {
    // sort-based fallback
    regularHashMap.destructAndCreateExternalSorter();
    ...
    counter = 0;
  }
}
counter += 1;

Counter example:

1. key_a is inserted into 1st level map (counter = 0)
2. a couple of NULL keys are inserted into 2nd level map (count = 2). Note: 1st level map does not support NULL key.
3. key_a is inserted into 2nd level map (count1 = 2)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea is to simulate the size-based fallback: "no space" -> "reach the capacity/limit"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated per offline discussion. Changed the first level fallback by restricting first level map capacity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also just FYI - I updated the generated code in PR description for checking if needed.

(modes.contains(Final) || modes.contains(Complete))

override def usedInputs: AttributeSet = inputSet

override def supportCodegen: Boolean = {
Expand Down Expand Up @@ -435,8 +445,8 @@ case class HashAggregateExec(
)
}

def getTaskMemoryManager(): TaskMemoryManager = {
TaskContext.get().taskMemoryManager()
def getTaskContext(): TaskContext = {
TaskContext.get()
}

def getEmptyAggregationBuffer(): InternalRow = {
Expand Down Expand Up @@ -537,6 +547,34 @@ case class HashAggregateExec(
}
}

/**
* Called by generated Java class to finish merge the fast hash map into regular map.
* This is used for testing final aggregate only.
*/
def mergeFastHashMapForTest(
fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow],
regularHashMap: UnsafeFixedWidthAggregationMap): Unit = {

// Create a MutableProjection to merge the buffers of same key together
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = MutableProjection.create(
mergeExpr,
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
val joinedRow = new JoinedRow()

while (fastHashMapRowIter.next()) {
val key = fastHashMapRowIter.getKey
val fastMapBuffer = fastHashMapRowIter.getValue
val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key)

// Merge the aggregation buffer of fast hash map, into the buffer with same key of
// regular map
mergeProjection.target(regularMapBuffer)
mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer))
}
fastHashMapRowIter.close()
}

/**
* Generate the code for output.
* @return function name for the result code.
Expand Down Expand Up @@ -647,7 +685,7 @@ case class HashAggregateExec(
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
f.dataType.isInstanceOf[CalendarIntervalType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
bufferSchema.nonEmpty

// For vectorized hash map, We do not support byte array based decimal type for aggregate values
// as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
Expand All @@ -663,7 +701,7 @@ case class HashAggregateExec(

private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
if (!checkIfFastHashMapSupported(ctx)) {
if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
Copy link
Contributor

@cloud-fan cloud-fan Apr 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last question: can we search the commit history and figure out why we didn't enable the fast hash map in the final aggregate? It seems we did it on purpose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - I was wondering at first place before making this PR as well. The decision to only support partial aggregate is made when the first level hash map was introduced (#12345 and #14176), and never changed afterwards. I checked with @sameeragarwal before making this PR. He told me there is no fundamental reason to not support final aggregate.

Just for documentation, I asked him why we don't support nested type (array/map/struct) as key type for fast hash map. He told me the reason was the size of keys might be too large for long array/map/struct, so the size of fast hash map may not fit in cache and lose the benefit.

if (!Utils.isTesting) {
logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but"
+ " current version of codegened fast hashmap does not support this aggregate.")
}
Expand All @@ -687,40 +725,59 @@ case class HashAggregateExec(

val thisPlan = ctx.addReferenceObj("plan", this)

// Create a name for the iterator from the fast hash map, and the code to create fast hash map.
val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) {
// Generates the fast hash map class and creates the fast hash map term.
val fastHashMapClassName = ctx.freshName("FastHashMap")
if (isVectorizedHashMapEnabled) {
val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"java.util.Iterator<InternalRow>",
"vectorizedFastHashMapIter",
forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
(iter, create)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "fastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter", forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());"
(iter, create)
}
} else ("", "")
// Create a name for the iterator from the fast hash map, the code to create
// and add hook to close fast hash map.
val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) =
if (isFastHashMapEnabled) {
// Generates the fast hash map class and creates the fast hash map term.
val fastHashMapClassName = ctx.freshName("FastHashMap")
val (iter, create) = if (isVectorizedHashMapEnabled) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changed compared to previous code.

val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"java.util.Iterator<InternalRow>",
"vectorizedFastHashMapIter",
forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
(iter, create)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "fastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter", forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
s"$thisPlan.getTaskContext().taskMemoryManager(), " +
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changed compared to previous code.

s"$thisPlan.getEmptyAggregationBuffer());"
(iter, create)
}

// Generates the code to register a cleanup task with TaskContext to ensure that memory
// is guaranteed to be freed at the end of the task. This is necessary to avoid memory
// leaks in when the downstream operator does not fully consume the aggregation map's
// output (e.g. aggregate followed by limit).
val hookToCloseFastHashMap =
s"""
|$thisPlan.getTaskContext().addTaskCompletionListener(
| new org.apache.spark.util.TaskCompletionListener() {
| @Override
| public void onTaskCompletion(org.apache.spark.TaskContext context) {
| $fastHashMapTerm.close();
| }
|});
""".stripMargin
(iter, create, hookToCloseFastHashMap)
Copy link
Contributor Author

@c21 c21 Apr 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only real change inside L728-779. Add a new hookToCloseFastHashMap here to clean up fast hash map. The other code is not changed except indentation. Not sure why github highlights so many change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change the code in a less diff way?

...
val hookToCloseFastHashMap = if (isFastHashMapEnabled) {
 ...
} else ""

Copy link
Contributor

@cloud-fan cloud-fan Apr 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we put the logic in the base HashMapGenerator as a method, and call the method in both the vectorized and row-based fast hash map generator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated to go with #32242 (comment) . Thanks.

} else ("", "", "")

// Create a name for the iterator from the regular hash map.
// Inline mutable state since not many aggregation operations in a task
Expand All @@ -740,8 +797,13 @@ case class HashAggregateExec(
val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
val finishHashMap = if (isFastHashMapEnabled) {
val finishFastHashMap = if (isTestFinalAggregateWithFallback) {
s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);"
} else {
s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"
}
s"""
|$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
|$finishFastHashMap
|$finishRegularHashMap
""".stripMargin
} else {
Expand All @@ -761,8 +823,10 @@ case class HashAggregateExec(
val bufferTerm = ctx.freshName("aggBuffer")
val outputFunc = generateResultFunction(ctx)

val limitNotReachedCondition = limitNotReachedCond
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the limit early termination for first level map as well. This is needed to fix test failure SQLMetricsSuite.SPARK-25497: LIMIT within whole stage codegen should not consume all the inputs in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true. And this is good to have anyway.


def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) {
if (isVectorizedHashMapEnabled) {
outputFromVectorizedMap
} else {
Expand All @@ -773,7 +837,7 @@ case class HashAggregateExec(

def outputFromRowBasedMap: String = {
s"""
|while ($iterTermForFastHashMap.next()) {
|while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
Expand All @@ -798,7 +862,7 @@ case class HashAggregateExec(
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
|while ($iterTermForFastHashMap.hasNext()) {
|while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
Expand All @@ -813,7 +877,7 @@ case class HashAggregateExec(

def outputFromRegularHashMap: String = {
s"""
|while ($limitNotReachedCond $iterTerm.next()) {
|while ($limitNotReachedCondition $iterTerm.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
Expand All @@ -832,6 +896,7 @@ case class HashAggregateExec(
|if (!$initAgg) {
| $initAgg = true;
| $createFastHashMap
| $addHookToCloseFastHashMap
| $hashMapTerm = $thisPlan.createHashMap();
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
Expand Down