
Commit 51ce7be

address comment

1 parent 0a6c79a

7 files changed (+36, -23 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
Lines changed: 2 additions & 2 deletions

```diff
@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
        |if ($batch == null) {
        |  $nextBatchFuncName();
        |}
-       |while ($batch != null$keepProducingDataCond) {
+       |while ($batch != null$limitNotReachedCond) {
        |  int $numRows = $batch.numRows();
        |  int $localEnd = $numRows - $idx;
        |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
-      |while ($input.hasNext()$keepProducingDataCond) {
+      |while ($input.hasNext()$limitNotReachedCond) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
       |  ${consume(ctx, outputVars, inputRow).trim}
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
Lines changed: 7 additions & 6 deletions

```diff
@@ -132,11 +132,12 @@ case class SortExec(
   // a stop check before sorting.
   override def needStopCheck: Boolean = false
 
-  // Sort is a blocking operator. It needs to consume all the inputs before producing any
-  // output. This means, Limit after Sort has no effect to Sort's upstream operators.
-  // Here we override this method to return Nil, so that upstream operators will not generate
-  // unnecessary conditions (which is always evaluated to false) for the Limit after Sort.
-  override def conditionsOfKeepProducingData: Seq[String] = Nil
+  // Sort is a blocking operator. It needs to consume all the inputs before producing any output.
+  // This means, a Limit operator after Sort will never reach its limit during the execution of
+  // Sort's upstream operators. Here we override this method to return Nil, so that upstream
+  // operators will not generate useless conditions (which are always evaluated to true) for the
+  // Limit operators after Sort.
+  override def limitNotReachedChecks: Seq[String] = Nil
 
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort =
@@ -178,7 +179,7 @@ case class SortExec(
       |   $needToSort = false;
       | }
       |
-      | while ($sortedIterator.hasNext()$keepProducingDataCond) {
+      | while ($sortedIterator.hasNext()$limitNotReachedCond) {
       |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
       |   ${consume(ctx, null, outputRow)}
       |   if (shouldStop()) return;
```
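The updated comment encodes the key invariant of this commit: a blocking operator truncates the chain of checks, so producers upstream of it generate an unconditioned loop. A minimal sketch of the effect on the generated loops; the plan shape, iterator names, and counter name here are invented for illustration, and `cond` simply mirrors the `limitNotReachedCond` helper shown in the WholeStageCodegenExec diff below:

```scala
object BlockingOpDemo extends App {
  // Mirrors the limitNotReachedCond helper: an empty check list
  // yields an empty condition suffix.
  def cond(checks: Seq[String]): String =
    if (checks.isEmpty) "" else checks.mkString(" && ", " && ", "")

  // Hypothetical plan: a Limit above a Sort above a Scan.
  val checksBelowLimit: Seq[String] = Seq("count_1 < 10") // what the Limit passes down
  val checksBelowSort: Seq[String] = Nil                  // Sort overrides the chain to Nil

  // Sort's own output loop sits directly under the Limit, so it keeps the check:
  println(s"while (sortedIterator.hasNext()${cond(checksBelowLimit)}) { ... }")
  // -> while (sortedIterator.hasNext() && count_1 < 10) { ... }

  // The Scan below Sort sees Nil and produces an unconditioned loop:
  println(s"while (input.hasNext()${cond(checksBelowSort)}) { ... }")
  // -> while (input.hasNext()) { ... }
}
```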

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
Lines changed: 17 additions & 6 deletions

```diff
@@ -346,13 +346,24 @@ trait CodegenSupport extends SparkPlan {
    */
   def needStopCheck: Boolean = parent.needStopCheck
 
-  def conditionsOfKeepProducingData: Seq[String] = parent.conditionsOfKeepProducingData
+  /**
+   * A sequence of checks which evaluate to true if the downstream Limit operators have not yet
+   * received enough records to reach the limit. If the current node is a data producing node, it
+   * can leverage this information to stop producing data and complete the data flow earlier.
+   * Common data producing nodes are leaf nodes like Range and Scan, and blocking nodes like Sort
+   * and Aggregate. These checks should be put into the condition of the data producing loop.
+   */
+  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks
 
-  final protected def keepProducingDataCond: String = {
-    if (parent.conditionsOfKeepProducingData.isEmpty) {
+  /**
+   * A helper method to generate the data producing loop condition according to the
+   * limit-not-reached checks.
+   */
+  final def limitNotReachedCond: String = {
+    if (parent.limitNotReachedChecks.isEmpty) {
       ""
     } else {
-      parent.conditionsOfKeepProducingData.mkString(" && ", " && ", "")
+      parent.limitNotReachedChecks.mkString(" && ", " && ", "")
     }
   }
 }
@@ -391,7 +402,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-      | while ($input.hasNext()$keepProducingDataCond) {
+      | while ($input.hasNext()$limitNotReachedCond) {
      |   InternalRow $row = (InternalRow) $input.next();
      |   ${consume(ctx, null, row).trim}
      |   if (shouldStop()) return;
@@ -687,7 +698,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
 
   override def needStopCheck: Boolean = true
 
-  override def conditionsOfKeepProducingData: Seq[String] = Nil
+  override def limitNotReachedChecks: Seq[String] = Nil
 
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
```
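The `mkString` call in the new `limitNotReachedCond` helper is easy to misread: its first argument is a leading `" && "`, so the result can be spliced directly after an existing loop condition such as `true` or `$input.hasNext()`. A small runnable sketch of the construction; the check strings are made up for illustration:

```scala
object LimitNotReachedCondDemo extends App {
  // Hypothetical checks, as two stacked Limit operators might contribute them.
  val checks: Seq[String] = Seq("count_1 < 100", "count_2 < 10")

  // Same construction as limitNotReachedCond: start = " && ", sep = " && ", end = "".
  val cond: String =
    if (checks.isEmpty) "" else checks.mkString(" && ", " && ", "")

  println(cond)
  // -> " && count_1 < 100 && count_2 < 10"

  // Spliced into a data producing loop, as in the RangeExec hunk of this commit:
  println(s"while (true$cond) { ... }")
  // -> while (true && count_1 < 100 && count_2 < 10) { ... }
}
```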

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
Lines changed: 6 additions & 5 deletions

```diff
@@ -160,10 +160,11 @@ case class HashAggregateExec(
   override def needStopCheck: Boolean = false
 
   // Aggregate is a blocking operator. It needs to consume all the inputs before producing any
-  // output. This means, Limit after Aggregate has no effect to Aggregate's upstream operators.
-  // Here we override this method to return Nil, so that upstream operators will not generate
-  // unnecessary conditions (which is always evaluated to false) for the Limit after Aggregate.
-  override def conditionsOfKeepProducingData: Seq[String] = Nil
+  // output. This means, Limit operator after Aggregate will never reach its limit during the
+  // execution of Aggregate's upstream operators. Here we override this method to return Nil, so
+  // that upstream operators will not generate useless conditions (which are always evaluated to
+  // true) for the Limit operators after Aggregate.
+  override def limitNotReachedChecks: Seq[String] = Nil
 
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
@@ -711,7 +712,7 @@ case class HashAggregateExec(
 
   def outputFromRegularHashMap: String = {
     s"""
-      |while ($iterTerm.next()$keepProducingDataCond) {
+      |while ($iterTerm.next()$limitNotReachedCond) {
       |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
       |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
       |  $outputFunc($keyTerm, $bufferTerm);
```

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
Lines changed: 1 addition & 1 deletion

```diff
@@ -490,7 +490,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   $initRangeFuncName(partitionIndex);
       | }
       |
-      | while (true$keepProducingDataCond) {
+      | while (true$limitNotReachedCond) {
       |   if ($nextIndex == $batchEnd) {
       |     long $nextBatchTodo;
       |     if ($numElementsTodo > ${batchSize}L) {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
Lines changed: 1 addition & 1 deletion

```diff
@@ -623,7 +623,7 @@ case class SortMergeJoinExec(
     }
 
     s"""
-      |while (findNextInnerJoinRows($leftInput, $rightInput)$keepProducingDataCond) {
+      |while (findNextInnerJoinRows($leftInput, $rightInput)) {
       |  ${leftVarDecl.mkString("\n")}
       |  ${beforeLoop.trim}
       |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
```

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
Lines changed: 2 additions & 2 deletions

```diff
@@ -77,8 +77,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
 
   private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
 
-  override lazy val conditionsOfKeepProducingData: Seq[String] = {
-    s"$countTerm < $limit" +: super.conditionsOfKeepProducingData
+  override lazy val limitNotReachedChecks: Seq[String] = {
+    s"$countTerm < $limit" +: super.limitNotReachedChecks
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
```
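BaseLimitExec prepends its own counter check to the checks it inherits from its parent, which is how nested limits accumulate into one conjunction. A rough standalone model of that accumulation; the `Node` chain, operator stand-ins, and counter names are invented for illustration and are not the actual Spark classes:

```scala
object StackedLimitChecksDemo extends App {
  // A stripped-down stand-in for the CodegenSupport chain: by default,
  // each node simply forwards its parent's checks.
  trait Node {
    def parent: Option[Node]
    def limitNotReachedChecks: Seq[String] =
      parent.map(_.limitNotReachedChecks).getOrElse(Nil)
  }

  // Stand-in for WholeStageCodegenExec: the top of the chain contributes no checks.
  final case class Root() extends Node { val parent = None }

  // Stand-in for BaseLimitExec: prepend this limit's counter check, as in the diff.
  final case class Limit(counter: String, limit: Int, p: Node) extends Node {
    val parent = Some(p)
    override def limitNotReachedChecks: Seq[String] =
      s"$counter < $limit" +: super.limitNotReachedChecks
  }

  // A LIMIT 10 nested below a LIMIT 100: a data producer under both sees both checks.
  val plan = Limit("count_2", 10, Limit("count_1", 100, Root()))
  println(plan.limitNotReachedChecks)
  // -> List(count_2 < 10, count_1 < 100)
}
```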
