diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 74c9c0599271..3d0511b7ba83 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -73,16 +73,6 @@ public void append(InternalRow row) {
     currentRows.add(row);
   }
 
-  /**
-   * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
-   *
-   * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
-   * This interface is mainly used to limit the number of input rows.
-   */
-  public boolean stopEarly() {
-    return false;
-  }
-
   /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 48abad907865..9f6b59336080 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
       |if ($batch == null) {
       |  $nextBatchFuncName();
       |}
-      |while ($batch != null) {
+      |while ($limitNotReachedCond $batch != null) {
       |  int $numRows = $batch.numRows();
       |  int $localEnd = $numRows - $idx;
       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
-      |while ($input.hasNext()) {
+      |while ($limitNotReachedCond $input.hasNext()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
       |  ${consume(ctx, outputVars, inputRow).trim}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0dc16ba5ce28..f1470e45f129 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -39,7 +39,7 @@ case class SortExec(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   override def output: Seq[Attribute] = child.output
 
@@ -124,14 +124,6 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
-  // The result rows come from the sort buffer, so this operator doesn't need to copy its result
-  // even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Sort operator always consumes all the input rows before outputting any result, so we don't need
-  // a stop check before sorting.
-  override def needStopCheck: Boolean = false
-
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort",
       v => s"$v = true;")
@@ -172,7 +164,7 @@ case class SortExec(
       |     $needToSort = false;
       |   }
       |
-      |   while ($sortedIterator.hasNext()) {
+      |   while ($limitNotReachedCond $sortedIterator.hasNext()) {
       |     UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
       |     ${consume(ctx, null, outputRow)}
       |     if (shouldStop()) return;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1fc4de9e5601..f5aee627fe90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -345,6 +345,61 @@ trait CodegenSupport extends SparkPlan {
    * don't require shouldStop() in the loop of producing rows.
    */
   def needStopCheck: Boolean = parent.needStopCheck
+
+  /**
+   * A sequence of checks which evaluate to true if the downstream Limit operators have not yet
+   * received enough records to reach the limit. If the current node is a data-producing node, it
+   * can leverage this information to stop producing data and complete the data flow earlier.
+   * Common data-producing nodes are leaf nodes like Range and Scan, and blocking nodes like Sort
+   * and Aggregate. These checks should be put into the loop condition of the data-producing loop.
+   */
+  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks
+
+  /**
+   * A helper method to generate the data-producing loop condition according to the
+   * limit-not-reached checks.
+   */
+  final def limitNotReachedCond: String = {
+    // InputAdapter is also a leaf node.
+    val isLeafNode = children.isEmpty || this.isInstanceOf[InputAdapter]
+    if (!isLeafNode && !this.isInstanceOf[BlockingOperatorWithCodegen]) {
+      val errMsg = "Only leaf nodes and blocking nodes need to call 'limitNotReachedCond' " +
+        "in its data producing loop."
+      if (Utils.isTesting) {
+        throw new IllegalStateException(errMsg)
+      } else {
+        logWarning(s"[BUG] $errMsg Please open a JIRA ticket to report it.")
+      }
+    }
+    if (parent.limitNotReachedChecks.isEmpty) {
+      ""
+    } else {
+      parent.limitNotReachedChecks.mkString("", " && ", " &&")
+    }
+  }
+}
+
+/**
+ * A special kind of operator that supports whole-stage codegen. Blocking means these operators
+ * will consume all the inputs first, before producing output. Typical blocking operators are
+ * sort and aggregate.
+ */
+trait BlockingOperatorWithCodegen extends CodegenSupport {
+
+  // Blocking operators usually have some kind of buffer to keep the data before producing it, so
+  // they don't need to copy their result even if their child does.
+  override def needCopyResult: Boolean = false
+
+  // Blocking operators always consume all the input first, so their upstream operators don't need
+  // a stop check.
+  override def needStopCheck: Boolean = false
+
+  // Blocking operators need to consume all the inputs before producing any output. This means a
+  // Limit operator after this blocking operator will never reach its limit during the execution of
+  // this blocking operator's upstream operators. Here we override this method to return Nil, so
+  // that upstream operators will not generate useless conditions (which are always evaluated to
+  // false) for the Limit operators after this blocking operator.
+  override def limitNotReachedChecks: Seq[String] = Nil
 }
 
@@ -381,7 +436,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext() && !stopEarly()) {
+       | while ($limitNotReachedCond $input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
       |   ${consume(ctx, null, row).trim}
       |   if (shouldStop()) return;
@@ -677,6 +732,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
 
   override def needStopCheck: Boolean = true
 
+  override def limitNotReachedChecks: Seq[String] = Nil
+
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 98adba50b297..6155ec9d30db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -45,7 +45,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -151,14 +151,6 @@ case class HashAggregateExec(
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  // The result rows come from the aggregate buffer, or a single row(no grouping keys), so this
-  // operator doesn't need to copy its result even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Aggregate operator always consumes all the input rows before outputting any result, so we
-  // don't need a stop check before aggregating.
-  override def needStopCheck: Boolean = false
-
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
@@ -705,13 +697,16 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-        |while ($iterTerm.next()) {
+        |while ($limitNotReachedCond $iterTerm.next()) {
         |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
         |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
         |  $outputFunc($keyTerm, $bufferTerm);
-        |
         |  if (shouldStop()) return;
         |}
+        |$iterTerm.close();
+        |if ($sorterTerm == null) {
+        |  $hashMapTerm.free();
+        |}
       """.stripMargin
     }
 
@@ -728,11 +723,6 @@ case class HashAggregateExec(
       // output the result
       $outputFromFastHashMap
       $outputFromRegularHashMap
-
-      $iterTerm.close();
-      if ($sorterTerm == null) {
-        $hashMapTerm.free();
-      }
     """
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 222a1b8bc730..4cd2e788ade0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
-    val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
+    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
 
     val value = ctx.freshName("value")
     val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
@@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // within a batch, while the code in the outer loop is setting batch parameters and updating
     // the metrics.
 
-    // Once number == batchEnd, it's time to progress to the next batch.
+    // Once nextIndex == batchEnd, it's time to progress to the next batch.
     val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
 
     // How many values should still be generated by this range operator.
@@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        |
        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-       |     $number = Long.MAX_VALUE;
+       |     $nextIndex = Long.MAX_VALUE;
        |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-       |     $number = Long.MIN_VALUE;
+       |     $nextIndex = Long.MIN_VALUE;
        |   } else {
-       |     $number = st.longValue();
+       |     $nextIndex = st.longValue();
        |   }
-       |   $batchEnd = $number;
+       |   $batchEnd = $nextIndex;
        |
        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |     .multiply(step).add(start);
@@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        |   }
        |
        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
-       |     $BigInt.valueOf($number));
+       |     $BigInt.valueOf($nextIndex));
        |   $numElementsTodo = startToEnd.divide(step).longValue();
        |   if ($numElementsTodo < 0) {
        |     $numElementsTodo = 0;
@@ -452,12 +452,42 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
-    val range = ctx.freshName("range")
     val shouldStop = if (parent.needStopCheck) {
-      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+      s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
     } else {
       "// shouldStop check is eliminated"
     }
+    val loopCondition = if (limitNotReachedChecks.isEmpty) {
+      "true"
+    } else {
+      limitNotReachedChecks.mkString(" && ")
+    }
+
+    // An overview of the Range processing.
+    //
+    // For each partition, the Range task needs to produce records from partition start (inclusive)
+    // to end (exclusive). For better performance, we separate the partition range into batches, and
+    // use 2 loops to produce data. The outer while loop is used to iterate batches, and the inner
+    // for loop is used to iterate records inside a batch.
+    //
+    // `nextIndex` tracks the index of the next record that is going to be consumed, initialized
+    // with partition start. `batchEnd` tracks the end index of the current batch, initialized
+    // with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's true,
+    // it means the current batch is fully consumed, and we will update `batchEnd` to process the
+    // next batch. If `batchEnd` reaches partition end, exit the outer loop. Finally we enter the
+    // inner loop. Note that, when we enter the inner loop, `nextIndex` must be different from
+    // `batchEnd`, otherwise we would already have exited the outer loop.
+    //
+    // The inner loop iterates from 0 to `localEnd`, which is calculated by
+    // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step` in
+    // the outer loop, and initialized with `nextIndex`, `batchEnd - nextIndex` is always
+    // divisible by `step`. `nextIndex` is increased by `step` during each iteration, and ends up
+    // being equal to `batchEnd` when the inner loop finishes.
+    //
+    // The inner loop can be interrupted, if the query has produced at least one result row, so that
+    // we don't buffer too many result rows and waste memory. It's OK to interrupt the inner loop,
+    // because `nextIndex` will be updated before interrupting.
+
     s"""
       | // initialize Range
       | if (!$initTerm) {
@@ -465,33 +495,30 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
      |   $initRangeFuncName(partitionIndex);
      | }
      |
-      | while (true) {
-      |   long $range = $batchEnd - $number;
-      |   if ($range != 0L) {
-      |     int $localEnd = (int)($range / ${step}L);
-      |     for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
-      |       long $value = ((long)$localIdx * ${step}L) + $number;
-      |       ${consume(ctx, Seq(ev))}
-      |       $shouldStop
+      | while ($loopCondition) {
+      |   if ($nextIndex == $batchEnd) {
+      |     long $nextBatchTodo;
+      |     if ($numElementsTodo > ${batchSize}L) {
+      |       $nextBatchTodo = ${batchSize}L;
+      |       $numElementsTodo -= ${batchSize}L;
+      |     } else {
+      |       $nextBatchTodo = $numElementsTodo;
+      |       $numElementsTodo = 0;
+      |       if ($nextBatchTodo == 0) break;
      |     }
-      |     $number = $batchEnd;
+      |     $numOutput.add($nextBatchTodo);
+      |     $inputMetrics.incRecordsRead($nextBatchTodo);
+      |     $batchEnd += $nextBatchTodo * ${step}L;
      |   }
      |
-      |   $taskContext.killTaskIfInterrupted();
-      |
-      |   long $nextBatchTodo;
-      |   if ($numElementsTodo > ${batchSize}L) {
-      |     $nextBatchTodo = ${batchSize}L;
-      |     $numElementsTodo -= ${batchSize}L;
-      |   } else {
-      |     $nextBatchTodo = $numElementsTodo;
-      |     $numElementsTodo = 0;
-      |     if ($nextBatchTodo == 0) break;
+      |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
+      |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+      |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
+      |     ${consume(ctx, Seq(ev))}
+      |     $shouldStop
      |   }
-      |   $numOutput.add($nextBatchTodo);
-      |   $inputMetrics.incRecordsRead($nextBatchTodo);
-      |
-      |   $batchEnd += $nextBatchTodo * ${step}L;
+      |   $nextIndex = $batchEnd;
+      |   $taskContext.killTaskIfInterrupted();
      | }
     """.stripMargin
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 66bcda891373..9bfe1a79fc1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }
 
+object BaseLimitExec {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def newLimitCountTerm(): String = {
+    val id = curId.getAndIncrement()
+    s"_limit_counter_$id"
+  }
+}
+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
@@ -66,27 +75,25 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   // to the parent operator.
   override def usedInputs: AttributeSet = AttributeSet.empty
 
+  private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
+
+  override lazy val limitNotReachedChecks: Seq[String] = {
+    s"$countTerm < $limit" +: super.limitNotReachedChecks
+  }
+
   protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
-    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
+    // The counter name is already obtained by the upstream operators via `limitNotReachedChecks`.
+    // Here we have to inline it so that its name does not change. This is fine as we won't have
+    // many limit operators in one query.
+    ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
     s"""
       | if ($countTerm < $limit) {
       |   $countTerm += 1;
       |   ${consume(ctx, input)}
-      | } else {
-      |   $stopEarly = true;
       | }
     """.stripMargin
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 085a44548848..81db3e137964 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.metric
 
 import java.io.File
 
+import scala.reflect.{classTag, ClassTag}
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -518,56 +521,80 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
 
+  private def collectNodeWithinWholeStage[T <: SparkPlan : ClassTag](plan: SparkPlan): Seq[T] = {
+    val stages = plan.collect {
+      case w: WholeStageCodegenExec => w
+    }
+    assert(stages.length == 1, "The query plan should have one and only one whole-stage.")
+
+    val cls = classTag[T].runtimeClass
+    stages.head.collect {
+      case n if n.getClass == cls => n.asInstanceOf[T]
+    }
+  }
+
   test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") {
     def checkFilterAndRangeMetrics(
         df: DataFrame,
         filterNumOutputs: Int,
         rangeNumOutputs: Int): Unit = {
-      var filter: FilterExec = null
-      var range: RangeExec = null
-      val collectFilterAndRange: SparkPlan => Unit = {
-        case f: FilterExec =>
-          assert(filter == null, "the query should only have one Filter")
-          filter = f
-        case r: RangeExec =>
-          assert(range == null, "the query should only have one Range")
-          range = r
-        case _ =>
-      }
-      if (SQLConf.get.wholeStageEnabled) {
-        df.queryExecution.executedPlan.foreach {
-          case w: WholeStageCodegenExec =>
-            w.child.foreach(collectFilterAndRange)
-          case _ =>
-        }
-      } else {
-        df.queryExecution.executedPlan.foreach(collectFilterAndRange)
-      }
+      val plan = df.queryExecution.executedPlan
 
-      assert(filter != null && range != null, "the query doesn't have Filter and Range")
-      assert(filter.metrics("numOutputRows").value == filterNumOutputs)
-      assert(range.metrics("numOutputRows").value == rangeNumOutputs)
+      val filters = collectNodeWithinWholeStage[FilterExec](plan)
+      assert(filters.length == 1, "The query plan should have one and only one Filter")
+      assert(filters.head.metrics("numOutputRows").value == filterNumOutputs)
+
+      val ranges = collectNodeWithinWholeStage[RangeExec](plan)
+      assert(ranges.length == 1, "The query plan should have one and only one Range")
+      assert(ranges.head.metrics("numOutputRows").value == rangeNumOutputs)
     }
 
-    val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
-    val df2 = df.limit(2)
-    Seq(true, false).foreach { wholeStageEnabled =>
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) {
-        df.collect()
-        checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000)
-
-        df.queryExecution.executedPlan.foreach(_.resetMetrics())
-        // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition,
-        // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces
-        // 4 rows, and Range produces 2000 rows.
-        df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
-        checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000)
-
-        // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first
-        // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch).
-        df2.collect()
-        checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000)
-      }
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+      val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
+      df.collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000)
+
+      df.queryExecution.executedPlan.foreach(_.resetMetrics())
+      // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition,
+      // and Range should produce 1000 rows (one batch) per-partition. In total, Filter produces
+      // 4 rows, and Range produces 2000 rows.
+      df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000)
+
+      // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first
+      // task, so in total the Filter produces 2 rows, and Range produces 1000 rows (one batch).
+      val df2 = df.limit(2)
+      df2.collect()
+      checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000)
+    }
+  }
+
+  test("SPARK-25497: LIMIT within whole stage codegen should not consume all the inputs") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+      // A special query that only has one partition, so there is no shuffle and the entire query
+      // can be whole-stage-codegened.
+      val df = spark.range(0, 1500, 1, 1).limit(10).groupBy('id).count().limit(1).filter('id >= 0)
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+
+      val ranges = collectNodeWithinWholeStage[RangeExec](plan)
+      assert(ranges.length == 1, "The query plan should have one and only one Range")
+      // The Range should only produce the first batch, i.e. 1000 rows.
+      assert(ranges.head.metrics("numOutputRows").value == 1000)
+
+      val aggs = collectNodeWithinWholeStage[HashAggregateExec](plan)
+      assert(aggs.length == 2, "The query plan should have two and only two Aggregate")
+      val partialAgg = aggs.filter(_.aggregateExpressions.head.mode == Partial).head
+      // The partial aggregate should output 10 rows, because its input is 10 rows.
+      assert(partialAgg.metrics("numOutputRows").value == 10)
+      val finalAgg = aggs.filter(_.aggregateExpressions.head.mode == Final).head
+      // The final aggregate should only produce 1 row, because the upstream limit only needs 1 row.
+      assert(finalAgg.metrics("numOutputRows").value == 1)
+
+      val filters = collectNodeWithinWholeStage[FilterExec](plan)
+      assert(filters.length == 1, "The query plan should have one and only one Filter")
+      // The final Filter should produce 1 row, because the input is just one row.
+      assert(filters.head.metrics("numOutputRows").value == 1)
     }
   }
 }
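
Reviewer note (illustrative only, not part of the patch): the sketch below models how the new `limitNotReachedChecks` flow is meant to work — a Limit operator contributes a `counter < limit` check, every other operator forwards its parent's checks, and a blocking operator cuts the chain by returning Nil so that upstream produce loops don't carry conditions that can never fail while it is still consuming input. The names here (`Node`, `LimitNode`, `BlockingNode`, `LeafNode`) are simplified stand-ins for the real `CodegenSupport` hierarchy, not actual Spark APIs.

object LimitCheckPropagationSketch {

  // Simplified stand-in for CodegenSupport: limit checks flow from parent to child.
  trait Node {
    def parent: Option[Node]

    // By default an operator just forwards its parent's checks.
    def limitNotReachedChecks: Seq[String] =
      parent.map(_.limitNotReachedChecks).getOrElse(Nil)

    // Turn the parent's checks into a loop-condition prefix, e.g. "_limit_counter_0 < 10 && ".
    final def limitNotReachedCond: String = {
      val checks = parent.map(_.limitNotReachedChecks).getOrElse(Nil)
      if (checks.isEmpty) "" else checks.mkString("", " && ", " && ")
    }
  }

  // A Limit-like operator prepends "counter < limit" to its parent's checks.
  final case class LimitNode(limit: Int, counter: String, parent: Option[Node]) extends Node {
    override def limitNotReachedChecks: Seq[String] =
      s"$counter < $limit" +: super.limitNotReachedChecks
  }

  // A blocking operator (sort- or aggregate-like) returns Nil, so upstream produce loops
  // don't carry checks that cannot become false while it is still consuming its input.
  final case class BlockingNode(parent: Option[Node]) extends Node {
    override def limitNotReachedChecks: Seq[String] = Nil
  }

  // A data-producing leaf, like Range or a scan.
  final case class LeafNode(parent: Option[Node]) extends Node

  def main(args: Array[String]): Unit = {
    // LIMIT 10 directly above a leaf: the leaf's produce loop is guarded by the counter.
    val limit10 = LimitNode(10, "_limit_counter_0", parent = None)
    val scan = LeafNode(parent = Some(limit10))
    println(scan.limitNotReachedCond + "input.hasNext()")
    // prints: _limit_counter_0 < 10 && input.hasNext()

    // LIMIT 1 above a blocking aggregate above a leaf: the leaf loses the check,
    // but the aggregate's own output loop keeps it.
    val limit1 = LimitNode(1, "_limit_counter_1", parent = None)
    val agg = BlockingNode(parent = Some(limit1))
    val scan2 = LeafNode(parent = Some(agg))
    println(scan2.limitNotReachedCond + "input.hasNext()") // prints: input.hasNext()
    println(agg.limitNotReachedCond + "iter.next()")       // prints: _limit_counter_1 < 1 && iter.next()
  }
}

Running the sketch prints a counter-guarded condition for a leaf directly under a LIMIT, and a plain condition for a leaf below a blocking aggregate while the aggregate's own output loop keeps the check — the same split the patch makes between `CodegenSupport` and `BlockingOperatorWithCodegen`.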
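
Reviewer note (illustrative only, not part of the patch): a plain-Scala rendering of the two-loop structure that the patched RangeExec generates. `limitReached` stands in for the negation of the generated limit-not-reached checks; metric updates, `shouldStop()` suspension and task-interruption checks are omitted. The batch size of 1000 with 1500 total elements mirrors the SPARK-25497 test above, and the helper names (`RangeLoopSketch`, `produce`) are made up for illustration.

object RangeLoopSketch {

  // One partition of a Range, produced in batches of `batchSize` records.
  def produce(
      start: Long,
      step: Long,
      numElements: Long,
      batchSize: Long,
      limitReached: () => Boolean,
      consume: Long => Unit): Unit = {
    var nextIndex = start        // index of the next record to produce
    var batchEnd = nextIndex     // end of the current batch (exclusive)
    var numElementsTodo = numElements

    // Outer loop: one iteration per batch, guarded by the limit checks.
    while (!limitReached()) {
      if (nextIndex == batchEnd) {
        // The current batch is fully produced: set up the next one, or stop if nothing is left.
        val nextBatchTodo =
          if (numElementsTodo > batchSize) { numElementsTodo -= batchSize; batchSize }
          else { val todo = numElementsTodo; numElementsTodo = 0; todo }
        if (nextBatchTodo == 0) return
        batchEnd += nextBatchTodo * step
      }
      // Inner loop: produce the whole batch. The generated code may additionally return early
      // here once a row has been buffered (the `shouldStop()` check), resuming from `nextIndex`.
      val localEnd = ((batchEnd - nextIndex) / step).toInt
      var localIdx = 0
      while (localIdx < localEnd) {
        consume(nextIndex + localIdx * step)
        localIdx += 1
      }
      nextIndex = batchEnd
    }
  }

  def main(args: Array[String]): Unit = {
    var produced = 0L
    // A downstream "LIMIT 10": the limit is only checked between batches, so the Range still
    // finishes the batch it started and produces 1000 rows, matching the test expectation.
    produce(0L, 1L, 1500L, 1000L, () => produced >= 10, _ => produced += 1)
    println(produced) // 1000
  }
}

Because the limit check sits only in the outer (per-batch) loop, a batch that has started is always finished; that is why the test expects 1000 output rows from Range rather than exactly 10.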