[SPARK-25497][SQL] Limit operation within whole stage codegen should not consume all the inputs #22630
@@ -345,6 +345,27 @@ trait CodegenSupport extends SparkPlan {
    * don't require shouldStop() in the loop of producing rows.
    */
   def needStopCheck: Boolean = parent.needStopCheck

+  /**
+   * A sequence of checks which evaluate to true if the downstream Limit operators have not
+   * received enough records and reached the limit. If current node is a data producing node, it
+   * can leverage this information to stop producing data and complete the data flow earlier.
+   * Common data producing nodes are leaf nodes like Range and Scan, and blocking nodes like Sort
+   * and Aggregate. These checks should be put into the loop condition of the data producing loop.
+   */
+  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks
+
+  /**
+   * A helper method to generate the data producing loop condition according to the
+   * limit-not-reached checks.
+   */
+  final def limitNotReachedCond: String = {
+    if (parent.limitNotReachedChecks.isEmpty) {
|
Contributor:
Just one thought: since we propagate (correctly) the [...]. The reason I'd like to do this is to enforce that we are not introducing the same limit condition check more than once, in more than one operator, which would be useless and may cause a (small) perf issue. WDYT?

Contributor (Author):
It's not very useful to enforce that. The consequence is so minor that I don't think it's worth the complexity. I want to have a simple and robust framework for the limit optimization first.

Contributor:
Yes, I 100% agree; that's why I'd like to detect early all the possible situations which we think impossible but which may happen in corner cases we are not considering. What I am suggesting here is to enforce this and fail, for testing only of course; in production we shouldn't do anything similar.
+      ""
+    } else {
+      parent.limitNotReachedChecks.mkString("", " && ", " &&")
|
Member:
nit: I am a bit afraid about 64KB Java bytecode overflow by using [...]
+    }
+  }
 }
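
As an aside for readers, the contract above can be exercised with a minimal standalone sketch (plain Scala, outside Spark; the counter name is hypothetical):

    // Each Limit operator contributes a check over its own counter variable.
    val exampleChecks = Seq("_limit_counter_0 < 10")

    def renderCond(checks: Seq[String]): String =
      if (checks.isEmpty) "" else checks.mkString("", " && ", " &&")

    // A producer splices the result in front of its own loop test, so both the
    // empty string and "_limit_counter_0 < 10 &&" form a valid Java condition:
    val loop = s"while (${renderCond(exampleChecks)} input.hasNext()) { ... }"
    // -> while (_limit_counter_0 < 10 && input.hasNext()) { ... }

The trailing " &&" (rather than a separator-only join) is what lets callers append their own condition without caring whether any checks exist.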
@@ -381,7 +402,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext() && !stopEarly()) {
+       | while ($limitNotReachedCond $input.hasNext()) {
       | InternalRow $row = (InternalRow) $input.next();
       | ${consume(ctx, null, row).trim}
       | if (shouldStop()) return;

@@ -677,6 +698,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)

   override def needStopCheck: Boolean = true

+  override def limitNotReachedChecks: Seq[String] = Nil
+
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
@@ -159,6 +159,13 @@ case class HashAggregateExec(
   // don't need a stop check before aggregating.
   override def needStopCheck: Boolean = false

+  // Aggregate is a blocking operator. It needs to consume all the inputs before producing any
+  // output. This means, Limit operator after Aggregate will never reach its limit during the
+  // execution of Aggregate's upstream operators. Here we override this method to return Nil, so
+  // that upstream operators will not generate useless conditions (which are always evaluated to
+  // true) for the Limit operators after Aggregate.
+  override def limitNotReachedChecks: Seq[String] = Nil
+
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
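
To make the blocking behavior concrete, here is an illustrative plan shape (not taken from this diff; the counter name is hypothetical):

    // Simplified physical plan:
    //   GlobalLimit 10
    //   +- HashAggregate(keys = [k])
    //      +- Range(0, 1000000)
    //
    // The Limit contributes a check like "_limit_counter_0 < 10", but the
    // aggregate must consume the entire Range output before emitting any row,
    // so that check can never fire while Range is still producing. Returning
    // Nil here resets the propagation: Range's loop is generated without the
    // always-true test, while the check still guards the loop that emits the
    // aggregated rows (see outputFromRegularHashMap below).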
|
|
@@ -705,13 +712,16 @@ case class HashAggregateExec(

     def outputFromRegularHashMap: String = {
       s"""
-        |while ($iterTerm.next()) {
+        |while ($limitNotReachedCond $iterTerm.next()) {
        | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
        | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
        | $outputFunc($keyTerm, $bufferTerm);
        |
        | if (shouldStop()) return;
        |}
+       |$iterTerm.close();
Contributor:
This is an unrelated change, right? It changes nothing in the generated code, right? Just want to double-check I am not missing something. (What changes is that before we were not doing the cleanup in case of a limit operator, whereas now we do; I see this.)

Contributor (Author):
Yes, it's unrelated and is a noop.
+       |if ($sorterTerm == null) {
+       |  $hashMapTerm.free();
+       |}
       """.stripMargin
     }

@@ -728,11 +738,6 @@ case class HashAggregateExec(
         // output the result
         $outputFromFastHashMap
         $outputFromRegularHashMap
-
-        $iterTerm.close();
-        if ($sorterTerm == null) {
-          $hashMapTerm.free();
-        }
       """
   }
@@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val numOutput = metricTerm(ctx, "numOutputRows")

     val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
-    val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
+    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

     val value = ctx.freshName("value")
     val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))

@@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // within a batch, while the code in the outer loop is setting batch parameters and updating
     // the metrics.

-    // Once number == batchEnd, it's time to progress to the next batch.
+    // Once nextIndex == batchEnd, it's time to progress to the next batch.
     val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")

     // How many values should still be generated by this range operator.

@@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |
       | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
       | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-      |   $number = Long.MAX_VALUE;
+      |   $nextIndex = Long.MAX_VALUE;
       | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-      |   $number = Long.MIN_VALUE;
+      |   $nextIndex = Long.MIN_VALUE;
       | } else {
-      |   $number = st.longValue();
+      |   $nextIndex = st.longValue();
       | }
-      | $batchEnd = $number;
+      | $batchEnd = $nextIndex;
       |
       | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
       |   .multiply(step).add(start);

@@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       | }
       |
       | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
-      |   $BigInt.valueOf($number));
+      |   $BigInt.valueOf($nextIndex));
       | $numElementsTodo = startToEnd.divide(step).longValue();
       | if ($numElementsTodo < 0) {
       |   $numElementsTodo = 0;

@@ -452,46 +452,73 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
-    val range = ctx.freshName("range")
     val shouldStop = if (parent.needStopCheck) {
-      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+      s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
Contributor:
In this case we are not very accurate in the metrics, right? I mean, we always say that we are returning a full batch, even though we may have consumed fewer rows than a batch. What about updating the metrics before returning? Something like [...]

Contributor (Author):
You are right about the problem, but I'm not going to touch this part in this PR. Note that this PR focuses on limit in whole stage codegen. Personally I feel it's ok to make the metrics a little inaccurate for better performance; we can discuss it later in other PRs.

Contributor (Author):
BTW I do have a local branch that fixed this problem, I just don't have time to benchmark it yet. I'll send it out later and let's move the discussion there.

Contributor:
I am not sure why you need a benchmark for this (unless you did something different from what I suggested in the earlier comment). In that case it is a single metric update which happens only when stopping; it shouldn't introduce any significant overhead. Am I missing something? Anyway, let's move the discussion to the next PR then, thanks.

Contributor (Author):
[...]

Contributor:
But [...]

Contributor (Author):
[...] Anyway, JVM JIT is mysterious and we need to be super careful when updating this kind of hot loop. That said, I'm not confident of any change to the hot loop without a benchmark.

Contributor:
OK, let's get back to this eventually later; this is anyway not worse than before.

Member:
Sorry for the late comment. It would be good to discuss the details in another PR. First, I agree benchmarking is necessary. Here are my thoughts: [...]
     } else {
       "// shouldStop check is eliminated"
     }
+    val loopCondition = if (limitNotReachedChecks.isEmpty) {
+      "true"
+    } else {
+      limitNotReachedChecks.mkString(" && ")
|
Member:
nit: I am a bit afraid about 64KB Java bytecode overflow by using [...]

Contributor (Author):
This is whole-stage codegen. If bytecode overflow happens, we will fall back.
+    }

+    // An overview of the Range processing.
+    //
+    // For each partition, the Range task needs to produce records from partition start(inclusive)
+    // to end(exclusive). For better performance, we separate the partition range into batches, and
+    // use 2 loops to produce data. The outer while loop is used to iterate batches, and the inner
+    // for loop is used to iterate records inside a batch.
+    //
+    // `nextIndex` tracks the index of the next record that is going to be consumed, initialized
+    // with partition start. `batchEnd` tracks the end index of the current batch, initialized
+    // with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's true,
+    // it means the current batch is fully consumed, and we will update `batchEnd` to process the
+    // next batch. If `batchEnd` reaches partition end, exit the outer loop. Finally we enter the
+    // inner loop. Note that, when we enter the inner loop, `nextIndex` must be different from
+    // `batchEnd`, otherwise we would already have exited the outer loop.
+    //
+    // The inner loop iterates from 0 to `localEnd`, which is calculated by
+    // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step` in
+    // the outer loop, and initialized with `nextIndex`, `batchEnd - nextIndex` is always divisible
+    // by `step`. `nextIndex` is increased by `step` during each iteration, and ends up being equal
+    // to `batchEnd` when the inner loop finishes.
+    //
+    // The inner loop can be interrupted, if the query has produced at least one result row, so
+    // that we don't buffer too many result rows and waste memory. It's ok to interrupt the inner
+    // loop, because `nextIndex` will be updated before interrupting.

| s""" | ||
| | // initialize Range | ||
| | if (!$initTerm) { | ||
| | $initTerm = true; | ||
| | $initRangeFuncName(partitionIndex); | ||
| | } | ||
| | | ||
| | while (true) { | ||
| | long $range = $batchEnd - $number; | ||
| | if ($range != 0L) { | ||
| | int $localEnd = (int)($range / ${step}L); | ||
| | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { | ||
| | long $value = ((long)$localIdx * ${step}L) + $number; | ||
| | ${consume(ctx, Seq(ev))} | ||
| | $shouldStop | ||
| | while ($loopCondition) { | ||
| | if ($nextIndex == $batchEnd) { | ||
| | long $nextBatchTodo; | ||
| | if ($numElementsTodo > ${batchSize}L) { | ||
| | $nextBatchTodo = ${batchSize}L; | ||
| | $numElementsTodo -= ${batchSize}L; | ||
| | } else { | ||
| | $nextBatchTodo = $numElementsTodo; | ||
| | $numElementsTodo = 0; | ||
| | if ($nextBatchTodo == 0) break; | ||
| | } | ||
| | $number = $batchEnd; | ||
| | $numOutput.add($nextBatchTodo); | ||
| | $inputMetrics.incRecordsRead($nextBatchTodo); | ||
| | $batchEnd += $nextBatchTodo * ${step}L; | ||
| | } | ||
| | | ||
| | $taskContext.killTaskIfInterrupted(); | ||
| | | ||
| | long $nextBatchTodo; | ||
| | if ($numElementsTodo > ${batchSize}L) { | ||
| | $nextBatchTodo = ${batchSize}L; | ||
| | $numElementsTodo -= ${batchSize}L; | ||
| | } else { | ||
| | $nextBatchTodo = $numElementsTodo; | ||
| | $numElementsTodo = 0; | ||
| | if ($nextBatchTodo == 0) break; | ||
| | int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L); | ||
|
||
| | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { | ||
| | long $value = ((long)$localIdx * ${step}L) + $nextIndex; | ||
| | ${consume(ctx, Seq(ev))} | ||
| | $shouldStop | ||
| | } | ||
| | $numOutput.add($nextBatchTodo); | ||
| | $inputMetrics.incRecordsRead($nextBatchTodo); | ||
| | | ||
| | $batchEnd += $nextBatchTodo * ${step}L; | ||
| | $nextIndex = $batchEnd; | ||
| | $taskContext.killTaskIfInterrupted(); | ||
| | } | ||
| """.stripMargin | ||
| } | ||
|
|
||
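
To see how the pieces fit, here is a standalone, hedged model of the two-loop pattern (ordinary Scala, not the PR's generated Java; names mirror the generated variables):

    // Produces [start, end) with the given step, in batches of `batchSize`,
    // stopping at batch granularity once `limit` records have been emitted.
    def produceRange(start: Long, end: Long, step: Long, batchSize: Long, limit: Int): Seq[Long] = {
      val out = scala.collection.mutable.ArrayBuffer.empty[Long]
      var nextIndex = start            // index of the next record to produce
      var batchEnd = nextIndex         // end index of the current batch
      var todo = (end - start) / step  // records left (assumes step evenly divides the range)

      while (out.size < limit) {       // plays the role of limitNotReachedCond
        if (nextIndex == batchEnd) {   // current batch fully consumed: set up the next one
          val nextBatchTodo = math.min(todo, batchSize)
          todo -= nextBatchTodo
          if (nextBatchTodo == 0) return out.toSeq
          batchEnd += nextBatchTodo * step
        }
        val localEnd = ((batchEnd - nextIndex) / step).toInt
        for (localIdx <- 0 until localEnd) {
          out += localIdx * step + nextIndex  // where consume() would emit a row
        }
        nextIndex = batchEnd             // updated before any early exit, as the comment explains
      }
      out.toSeq
    }

For example, produceRange(0, 1000000, 1, 1000, 10) emits a single 1000-record batch and stops, rather than producing all one million values. Like the generated code, the limit is tested only at batch boundaries; the Limit operator itself still enforces the exact row count.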
@@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }

+object BaseLimitExec {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def newLimitCountTerm(): String = {
+    val id = curId.getAndIncrement()
+    s"_limit_counter_$id"
+  }
|
Member:
Can't we use [...]?

Contributor (Author):
There is no [...]

Contributor (Author):
See [...]
+}
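
For illustration, the shared AtomicInteger makes the generated counter names unique across all limit operators in one query (a sketch of the observable behavior, assuming a fresh JVM):

    BaseLimitExec.newLimitCountTerm()  // "_limit_counter_0"
    BaseLimitExec.newLimitCountTerm()  // "_limit_counter_1"

This matters because the name is baked into the checks that upstream operators splice into their loops before this operator's doConsume ever runs.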

 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].

@@ -66,27 +75,25 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   // to the parent operator.
   override def usedInputs: AttributeSet = AttributeSet.empty

+  private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
+
+  override lazy val limitNotReachedChecks: Seq[String] = {
+    s"$countTerm < $limit" +: super.limitNotReachedChecks
+  }
+
   protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
-    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
+    // The counter name is already obtained by the upstream operators via `limitNotReachedChecks`.
+    // Here we have to inline it to not change its name. This is fine as we won't have many limit
+    // operators in one query.
+    ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
Contributor:
Why do we need to [...]?

Contributor (Author):
Because the counter variable name is decided before we obtain the [...]
| s""" | ||
| | if ($countTerm < $limit) { | ||
| | $countTerm += 1; | ||
| | ${consume(ctx, input)} | ||
| | } else { | ||
| | $stopEarly = true; | ||
| | } | ||
| """.stripMargin | ||
| } | ||
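
Taken together with limitNotReachedChecks above, the resulting generated Java for a limit of 5 would take roughly this shape (a hedged sketch; the counter name is illustrative and the surrounding loop comes from the upstream producer, e.g. InputAdapter):

    // private int _limit_counter_0 = 0;       // inlined, name fixed up front
    //
    // while (_limit_counter_0 < 5 && input.hasNext()) {   // producer's loop
    //   InternalRow row = (InternalRow) input.next();
    //   if (_limit_counter_0 < 5) {           // emitted by doConsume above
    //     _limit_counter_0 += 1;
    //     ...                                  // downstream consume code
    //   }
    //   if (shouldStop()) return;
    // }
    //
    // There is no else branch setting stopEarly() any more: once the counter
    // reaches the limit, the producer's loop condition (built from
    // limitNotReachedChecks) stops the data flow.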
Contributor:
It seems that all blocking operators will have this behavior. Shall we rather have a `blockingOperator` flag def and make this a final function incorporating this logic there?

Contributor (Author):
It's only done in Sort and Aggregate currently. I don't want to overdesign it until there are more use cases.

Contributor:
I am fine with doing it later, but I'd like to avoid having other places where we duplicate this logic in the future, in order to avoid possible mistakes.