sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java

@@ -73,16 +73,6 @@ public void append(InternalRow row) {
     currentRows.add(row);
   }
 
-  /**
-   * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
-   *
-   * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
-   * This interface is mainly used to limit the number of input rows.
-   */
-  public boolean stopEarly() {
-    return false;
-  }
-
   /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
    *
sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
        |if ($batch == null) {
        |  $nextBatchFuncName();
        |}
-       |while ($batch != null) {
+       |while ($batch != null$keepProducingDataCond) {
        |  int $numRows = $batch.numRows();
        |  int $localEnd = $numRows - $idx;
        |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {

@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
-       |while ($input.hasNext()) {
+       |while ($input.hasNext()$keepProducingDataCond) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
        |  ${consume(ctx, outputVars, inputRow).trim}
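The same one-line pattern recurs in every producer loop this PR touches: `$keepProducingDataCond` is appended to the loop condition. As a hedged illustration of what the interpolation evaluates to (the counter name and limit value below are hypothetical, not taken from the PR):

    // What keepProducingDataCond contributes to a generated loop header.
    // With no limit in the parent chain, the condition list is empty:
    val none: Seq[String] = Nil
    val condA = if (none.isEmpty) "" else none.mkString(" && ", " && ", "")
    assert(s"while (batch != null$condA) {" == "while (batch != null) {")

    // With one limit contributing "_limit_counter_0 < 21" (hypothetical name):
    val one = Seq("_limit_counter_0 < 21")
    val condB = if (one.isEmpty) "" else one.mkString(" && ", " && ", "")
    assert(s"while (batch != null$condB) {" ==
      "while (batch != null && _limit_counter_0 < 21) {")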
sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala

@@ -132,6 +132,10 @@ case class SortExec(
   // a stop check before sorting.
   override def needStopCheck: Boolean = false
 
+  // Sort operator always consumes all the input rows before outputting any result, so its upstream
+  // operators can keep producing data, even if there is a limit after Sort.
+  override def conditionsOfKeepProducingData: Seq[String] = Nil
+
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort =
       ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")

@@ -172,7 +176,7 @@ case class SortExec(
        |   $needToSort = false;
        | }
        |
-       | while ($sortedIterator.hasNext()) {
+       | while ($sortedIterator.hasNext()$keepProducingDataCond) {
        |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
        |   ${consume(ctx, null, outputRow)}
        |   if (shouldStop()) return;
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

@@ -345,6 +345,16 @@ trait CodegenSupport extends SparkPlan {
    * don't require shouldStop() in the loop of producing rows.
    */
   def needStopCheck: Boolean = parent.needStopCheck
 
+  def conditionsOfKeepProducingData: Seq[String] = parent.conditionsOfKeepProducingData

[Review — Member]: Can we briefly describe what these two methods do here?

+
+  final protected def keepProducingDataCond: String = {
+    if (parent.conditionsOfKeepProducingData.isEmpty) {
+      ""
+    } else {
+      parent.conditionsOfKeepProducingData.mkString(" && ", " && ", "")
+    }
+  }
 }
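To answer the reviewer's question concretely, here is a minimal standalone model (plain Scala, not Spark's classes; names simplified, and `keepProducingDataCond` made public so the asserts can run) of how the two methods cooperate: `conditionsOfKeepProducingData` flows from the parent (the downstream consumer) to the child, a limit prepends its counter condition, and a blocking operator resets the chain; `keepProducingDataCond` then renders the parent's conditions as a loop-header suffix.

    // Toy model of the parent chain in a whole-stage-codegen pipeline.
    trait Op {
      def parent: Op
      def conditionsOfKeepProducingData: Seq[String] = parent.conditionsOfKeepProducingData
      final def keepProducingDataCond: String =
        if (parent.conditionsOfKeepProducingData.isEmpty) ""
        else parent.conditionsOfKeepProducingData.mkString(" && ", " && ", "")
    }
    case class Root() extends Op {
      val parent: Op = null
      override def conditionsOfKeepProducingData: Seq[String] = Nil
    }
    case class Limit(parent: Op, counter: String, n: Int) extends Op {
      override def conditionsOfKeepProducingData: Seq[String] =
        s"$counter < $n" +: parent.conditionsOfKeepProducingData
    }
    case class Sort(parent: Op) extends Op {
      // Blocking: consumes all input before emitting any row, so operators
      // upstream of the Sort are not throttled by limits downstream of it.
      override def conditionsOfKeepProducingData: Seq[String] = Nil
    }
    case class Scan(parent: Op) extends Op

    val limited = Scan(Limit(Root(), "_limit_counter_0", 10))
    assert(limited.keepProducingDataCond == " && _limit_counter_0 < 10")
    // A Sort below the limit cuts the chain: the scan keeps producing freely.
    assert(Scan(Sort(Limit(Root(), "_limit_counter_0", 10))).keepProducingDataCond == "")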


@@ -381,7 +391,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext() && !stopEarly()) {
+       | while ($input.hasNext()$keepProducingDataCond) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;

@@ -677,6 +687,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
 
   override def needStopCheck: Boolean = true
 
+  override def conditionsOfKeepProducingData: Seq[String] = Nil
+
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

@@ -159,6 +159,10 @@ case class HashAggregateExec(
   // don't need a stop check before aggregating.
   override def needStopCheck: Boolean = false
 
+  // Aggregate operator always consumes all the input rows before outputting any result, so its
+  // upstream operators can keep producing data, even if there is a limit after Aggregate.

[Review — Member @viirya, Oct 4, 2018]: I have not looked at this in detail, but what if there is a limit before the Aggregate? We should not consume all the input rows in that case.

[Author — @cloud-fan]: Let's say the query is range -> limit -> agg -> limit. The agg does consume all of its input, which comes from the first limit. The range will have a stop check w.r.t. the first limit, not the second one. If there is no limit before the agg, then the range will not have a stop check.

+  override def conditionsOfKeepProducingData: Seq[String] = Nil
+
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
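As a hedged sketch of the plan shape discussed above (the constants and column alias are illustrative; assumes a live `SparkSession` named `spark`):

    import org.apache.spark.sql.functions.col

    // range -> limit -> agg -> limit: the scan gets its stop check from the
    // first limit only; the aggregate then consumes all 100 surviving rows,
    // and the trailing limit(5) never throttles the range scan.
    val df = spark.range(1000)            // range
      .limit(100)                         // first limit: throttles the scan
      .groupBy((col("id") % 10).as("k"))  // agg: blocking operator
      .count()
      .limit(5)                           // second limit: cut off by the agg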
@@ -705,13 +709,16 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-        |while ($iterTerm.next()) {
+        |while ($iterTerm.next()$keepProducingDataCond) {

[Author — @cloud-fan]: Here I only add the stop check for the regular hash map. The fast hash map is small and entirely in memory; it's OK to always output all of it.

         | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
         | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
         | $outputFunc($keyTerm, $bufferTerm);
         |
         | if (shouldStop()) return;
         |}
+        |$iterTerm.close();

[Review — Contributor]: This is an unrelated change, right? It changes nothing in the generated code, right? Just want to double-check I am not missing something. (What changes is that before we were not doing the cleanup in the case of a limit operator, and now we do; I see that.)

[Author — @cloud-fan]: Yes, it's unrelated and a no-op. outputFromRowBasedMap and outputFromVectorizedMap put the resource closing at the end; I want to be consistent here.

+        |if ($sorterTerm == null) {
+        |  $hashMapTerm.free();
+        |}
       """.stripMargin
     }

@@ -728,11 +735,6 @@ case class HashAggregateExec(
       // output the result
       $outputFromFastHashMap
       $outputFromRegularHashMap
-
-      $iterTerm.close();
-      if ($sorterTerm == null) {
-        $hashMapTerm.free();
-      }
     """
   }
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

@@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
-    val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
+    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
 
     val value = ctx.freshName("value")
     val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))

@@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // within a batch, while the code in the outer loop is setting batch parameters and updating
     // the metrics.
 
-    // Once number == batchEnd, it's time to progress to the next batch.
+    // Once nextIndex == batchEnd, it's time to progress to the next batch.
     val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
 
     // How many values should still be generated by this range operator.

@@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |
       | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
       | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-      |   $number = Long.MAX_VALUE;
+      |   $nextIndex = Long.MAX_VALUE;
       | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-      |   $number = Long.MIN_VALUE;
+      |   $nextIndex = Long.MIN_VALUE;
       | } else {
-      |   $number = st.longValue();
+      |   $nextIndex = st.longValue();
       | }
-      | $batchEnd = $number;
+      | $batchEnd = $nextIndex;
       |
       | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
       |   .multiply(step).add(start);

@@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       | }
       |
       | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
-      |   $BigInt.valueOf($number));
+      |   $BigInt.valueOf($nextIndex));
       | $numElementsTodo = startToEnd.divide(step).longValue();
       | if ($numElementsTodo < 0) {
       |   $numElementsTodo = 0;
@@ -452,46 +452,68 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
-    val range = ctx.freshName("range")
     val shouldStop = if (parent.needStopCheck) {
-      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+      s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
[Review — Contributor]: In this case we are not very accurate in the metrics, right? I mean, we always say that we are returning a full batch, even though we have consumed fewer rows than a batch. What about updating the metrics before returning? Something like $inputMetrics.incRecordsRead($localIdx - $localEnd);?

[Author — @cloud-fan]: You are right about the problem, but I'm not going to touch this part in this PR; note that this PR focuses on limit in whole-stage codegen. Personally I feel it's OK to make the metrics a little inaccurate for better performance; we can discuss it later in other PRs.

[Author — @cloud-fan]: BTW, I do have a local branch that fixes this problem, I just haven't had time to benchmark it yet. I'll send it out later and let's move the discussion there.

[Review — Contributor]: I am not sure why you need a benchmark for this (unless you did something different from what I suggested in the earlier comment). In that case it is a single metric update which happens only when stopping, so it shouldn't introduce any significant overhead. Am I missing something? Anyway, let's move the discussion to the next PR then, thanks.

[Author — @cloud-fan]: > Something like $inputMetrics.incRecordsRead($localIdx - $localEnd);?
localIdx is purely local to the loop; if we access it outside of the loop, we need to define localIdx outside of the loop as well. This may have some performance penalty. cc @kiszk

[Review — Contributor]: But shouldStop is called local to the loop, isn't it?

[Author — @cloud-fan]: shouldStop is called locally, but the metrics update is not. Anyway, JVM JIT is mysterious and we need to be super careful when updating this kind of hot loop. That said, I'm not confident in any change to the hot loop without a benchmark.

[Review — Contributor]: OK, let's get back to this later; this is anyway not worse than before.

[Review — Member @kiszk, Oct 8, 2018]: Sorry for the late comment; it would be good to discuss the details in another PR. First of all, I agree that benchmarking is necessary. Here are my thoughts:

  1. I think localIdx can be defined as a local variable outside of the loop. Or, how about storing localIdx into another local variable only if parent.needStopCheck is true?
  2. Since shouldStop() is simple and performs no updates, we expect the JIT to apply inlining and other optimizations.
  3. If we want to call incRecordsRead, it would be good to exit the loop using break and then call incRecordsRead.

     } else {
       "// shouldStop check is eliminated"
     }
 
+    // An overview of the Range processing.
+    //
+    // For each partition, the Range task needs to produce records from the partition start
+    // (inclusive) to end (exclusive). For better performance, we separate the partition range
+    // into batches, and use 2 loops to produce data. The outer while loop is used to iterate
+    // batches, and the inner for loop is used to iterate records inside a batch.
+    //
+    // `nextIndex` tracks the index of the next record that is going to be consumed, initialized
+    // with the partition start. `batchEnd` tracks the end index of the current batch, initialized
+    // with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's
+    // true, it means the current batch is fully consumed, and we will update `batchEnd` to
+    // process the next batch. If `batchEnd` reaches the partition end, exit the outer loop.
+    // finally we enter the inner loop. Note that, when we enter the inner loop, `nextIndex`
+    // must be different from `batchEnd`, otherwise the outer loop should have already exited.

[Review — Contributor]: Capital case for "finally".

+    //
+    // The inner loop iterates from 0 to `localEnd`, which is calculated by
+    // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step`
+    // in the outer loop, and initialized with `nextIndex`, `batchEnd - nextIndex` is always
+    // divisible by `step`. `nextIndex` is increased by `step` during each iteration, and ends
+    // up being equal to `batchEnd` when the inner loop finishes.
+    //
+    // The inner loop can be interrupted, if the query has produced at least one result row, so
+    // that we don't buffer too many result rows and waste memory. It's OK to interrupt the inner
+    // loop, because `nextIndex` will be updated before interrupting.

s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
| $initRangeFuncName(partitionIndex);
| }
|
| while (true) {
| long $range = $batchEnd - $number;
| if ($range != 0L) {
| int $localEnd = (int)($range / ${step}L);
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| long $value = ((long)$localIdx * ${step}L) + $number;
| ${consume(ctx, Seq(ev))}
| $shouldStop
| while (true$keepProducingDataCond) {
| if ($nextIndex == $batchEnd) {
| long $nextBatchTodo;
| if ($numElementsTodo > ${batchSize}L) {
| $nextBatchTodo = ${batchSize}L;
| $numElementsTodo -= ${batchSize}L;
| } else {
| $nextBatchTodo = $numElementsTodo;
| $numElementsTodo = 0;
| if ($nextBatchTodo == 0) break;
| }
| $number = $batchEnd;
| $numOutput.add($nextBatchTodo);
| $inputMetrics.incRecordsRead($nextBatchTodo);
| $batchEnd += $nextBatchTodo * ${step}L;
| }
|
| $taskContext.killTaskIfInterrupted();
|
| long $nextBatchTodo;
| if ($numElementsTodo > ${batchSize}L) {
| $nextBatchTodo = ${batchSize}L;
| $numElementsTodo -= ${batchSize}L;
| } else {
| $nextBatchTodo = $numElementsTodo;
| $numElementsTodo = 0;
| if ($nextBatchTodo == 0) break;
| int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
Copy link
Contributor Author

@cloud-fan cloud-fan Oct 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change here simply moves the inner loop after the batchEnd and metrics update, so that we can get correct metrics when we stop earlier because of limit.

| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| long $value = ((long)$localIdx * ${step}L) + $nextIndex;
| ${consume(ctx, Seq(ev))}
| $shouldStop
| }
| $numOutput.add($nextBatchTodo);
| $inputMetrics.incRecordsRead($nextBatchTodo);
|
| $batchEnd += $nextBatchTodo * ${step}L;
| $nextIndex = $batchEnd;
| $taskContext.killTaskIfInterrupted();
| }
""".stripMargin
}
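To see the restructured control flow in one place, here is a minimal standalone model of the new two-loop pattern (plain Scala; no metrics, task-kill checks, or codegen; a sketch of the pattern above, not the generated code itself):

    // Emits start, start+step, ... in batches of `batchSize` elements.
    // `nextIndex` and `batchEnd` mirror the generated mutable state.
    def produceRange(start: Long, numElements: Long, step: Long, batchSize: Long)
                    (emit: Long => Unit): Unit = {
      var nextIndex = start
      var batchEnd = start
      var numElementsTodo = numElements
      while (true) {
        if (nextIndex == batchEnd) {         // current batch fully consumed
          var nextBatchTodo = 0L
          if (numElementsTodo > batchSize) {
            nextBatchTodo = batchSize
            numElementsTodo -= batchSize
          } else {
            nextBatchTodo = numElementsTodo
            numElementsTodo = 0
            if (nextBatchTodo == 0) return   // partition end: exit outer loop
          }
          // numOutput/inputMetrics would be updated here, once per batch,
          // *before* the inner loop, so an early stop never loses the update.
          batchEnd += nextBatchTodo * step
        }
        val localEnd = ((batchEnd - nextIndex) / step).toInt
        var localIdx = 0
        while (localIdx < localEnd) {        // the hot inner loop
          emit(localIdx.toLong * step + nextIndex)
          localIdx += 1
        }
        nextIndex = batchEnd
      }
    }

    // produceRange(0L, 10L, 1L, 4L)(println) prints 0..9 in batches of 4, 4, 2.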
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

@@ -623,7 +623,7 @@ case class SortMergeJoinExec(
     }
 
     s"""
-       |while (findNextInnerJoinRows($leftInput, $rightInput)) {
+       |while (findNextInnerJoinRows($leftInput, $rightInput)$keepProducingDataCond) {
        |  ${leftVarDecl.mkString("\n")}
        |  ${beforeLoop.trim}
        |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala (28 changes: 16 additions & 12 deletions)

@@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }
 
+object BaseLimitExec {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def newLimitCountTerm(): String = {
+    val id = curId.getAndIncrement()
+    s"_limit_counter_$id"
+  }
+}

[Review — Member]: Can't we use freshName?

[Author — @cloud-fan]: There is no CodegenContext here.

[Author — @cloud-fan]: See MapObjects.apply as an existing example.

+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
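A quick illustration of the generator (the concrete suffixes are assumptions, since they depend on global call order within the JVM):

    // Each call returns a distinct, stable counter name, with no CodegenContext
    // required; the name can be decided before code generation begins.
    val a = BaseLimitExec.newLimitCountTerm()  // e.g. "_limit_counter_0"
    val b = BaseLimitExec.newLimitCountTerm()  // e.g. "_limit_counter_1"
    assert(a != b)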
@@ -66,27 +75,22 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   // to the parent operator.
   override def usedInputs: AttributeSet = AttributeSet.empty
 
+  private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
+
+  override lazy val conditionsOfKeepProducingData: Seq[String] = {
+    s"$countTerm < $limit" +: super.conditionsOfKeepProducingData
+  }

[Author — @cloud-fan]: Note that this is sub-optimal for adjacent limits, but I think it's fine, as the optimizer will merge adjacent limits.

+
   protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
-    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
+    ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)

[Review — Contributor]: Why do we need to forceInline?

[Author — @cloud-fan]: Because the counter variable name is decided before we obtain the CodegenContext. If we don't inline here, we need a way to notify the upstream operators about the counter name, which is hard to do.

     s"""
       | if ($countTerm < $limit) {
       |   $countTerm += 1;
       |   ${consume(ctx, input)}
-      | } else {
-      |   $stopEarly = true;
       | }
     """.stripMargin
   }
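Putting the pieces together, a hedged end-to-end model (plain Scala mimicking the generated logic for a hypothetical scan -> limit(3) stage; the counter name is illustrative):

    // The producer's while-loop carries the limit's keep-producing condition,
    // and the limit's doConsume body increments the shared inlined counter.
    var _limit_counter_0 = 0
    val input = Iterator(1, 2, 3, 4, 5)
    val out = scala.collection.mutable.ArrayBuffer.empty[Int]

    while (input.hasNext && _limit_counter_0 < 3) { // injected by keepProducingDataCond
      val row = input.next()
      if (_limit_counter_0 < 3) {                   // BaseLimitExec.doConsume
        _limit_counter_0 += 1
        out += row                                  // consume(ctx, input)
      }
    }
    assert(out.toSeq == Seq(1, 2, 3))               // scan stopped after 3 rows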