Commit: limit operation within whole stage codegen should not consume all the inputs.
viirya committed Sep 22, 2018
commit 12703bded143002be417ffa247eef4a970ffd54c
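The change in a nutshell, illustrated with a query adapted from the regression test added at the end of this diff (a sketch, not part of the patch; it assumes a SparkSession named `spark` with `import spark.implicits._` in scope for the 'col syntax): when a limit is compiled into the same whole-stage-codegen stage as the operators beneath it, the generated loops should stop pulling rows once the limit is satisfied rather than draining the entire input.

// Adapted from the new SQLQuerySuite test below; `spark` and
// `import spark.implicits._` are assumed to be in scope.
val filterDF = spark.range(0, 100, 1, 1)   // 100 rows in a single partition
  .filter('id >= 0)
  .selectExpr("id + 1 as id2")
  .limit(1)                                // fused into the same codegen stage
  .filter('id2 > 50)
filterDF.collect()
// Before this patch the Range and the first Filter processed all 100 rows even
// though the limit keeps only one; with the stopEarly() checks added in this
// commit their numOutputRows metrics become 1 each.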
@@ -243,7 +243,7 @@ case class HashAggregateExec(
val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
s"""
- | while (!$initAgg) {
+ | while (!$initAgg && !stopEarly()) {
| $initAgg = true;
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
@@ -665,7 +665,7 @@ case class HashAggregateExec(

def outputFromRowBasedMap: String = {
s"""
- |while ($iterTermForFastHashMap.next()) {
+ |while ($iterTermForFastHashMap.next() && !stopEarly()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -690,7 +690,7 @@ case class HashAggregateExec(
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
- |while ($iterTermForFastHashMap.hasNext()) {
+ |while ($iterTermForFastHashMap.hasNext() && !stopEarly()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
@@ -705,7 +705,7 @@ case class HashAggregateExec(

def outputFromRegularHashMap: String = {
s"""
- |while ($iterTerm.next()) {
+ |while ($iterTerm.next() && !stopEarly()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
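For context on the !stopEarly() conditions added above: the whole-stage-generated class extends BufferedRowIterator, whose stopEarly() hook returns false by default and is overridden by a limit operator compiled into the same stage. A minimal Scala model of that contract (assumed names, not Spark source):

// Minimal sketch: how the aggregate's output loops cooperate with a limit.
abstract class StageIteratorSketch {
  // Mirrors BufferedRowIterator#stopEarly(): false by default; a limit in the
  // same stage overrides it to return true once enough rows have been produced.
  protected def stopEarly(): Boolean = false

  // Output phase of the aggregate: drain the hash map, but bail out as soon as
  // the downstream limit reports it is satisfied.
  def drain[K, V](entries: Iterator[(K, V)])(emit: (K, V) => Unit): Unit = {
    while (entries.hasNext && !stopEarly()) {
      val (key, buffer) = entries.next()
      emit(key, buffer)
    }
  }
}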
@@ -465,13 +465,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $initRangeFuncName(partitionIndex);
| }
|
- | while (true) {
+ | while (true && !stopEarly()) {
| long $range = $batchEnd - $number;
| if ($range != 0L) {
| int $localEnd = (int)($range / ${step}L);
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| long $value = ((long)$localIdx * ${step}L) + $number;
+ | $numOutput.add(1);
Contributor:

Can this introduce a perf regression?

Member Author:

I'm not worried about it since it is a simple op.

Contributor:

This is very likely to hit a perf regression since it's not a tight loop anymore.

We want the range operator to stop earlier for better performance, but that doesn't mean the range operator must return exactly the limit number of records. Since the range operator already returns data in batches, I think we can stop earlier at batch granularity.

Member Author:

OK, then I should revert the numOutput change if the number of records can be a bit inaccurate.

+ | $inputMetrics.incRecordsRead(1);
| ${consume(ctx, Seq(ev))}
+ | if (stopEarly()) {
+ |   break;
+ | }
| $shouldStop
| }
| $number = $batchEnd;
@@ -488,9 +493,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $numElementsTodo = 0;
| if ($nextBatchTodo == 0) break;
| }
- | $numOutput.add($nextBatchTodo);
- | $inputMetrics.incRecordsRead($nextBatchTodo);
- |
| $batchEnd += $nextBatchTodo * ${step}L;
| }
""".stripMargin
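To make the reviewer's suggestion concrete, here is a small self-contained sketch (assumed names, plain Scala rather than generated code) contrasting a per-row early-stop check, which is what this commit generates for Range, with a per-batch check that keeps the inner loop tight at the cost of possibly producing up to one extra batch past the limit:

object EarlyStopSketch {
  // Flipped by a downstream limit once it has seen enough rows.
  @volatile var stopEarly: Boolean = false

  // Per-row check (what this commit does): exact output counts, but the inner
  // loop now contains a branch on every iteration.
  def produceRowGranularity(start: Long, end: Long, batchSize: Long)(emit: Long => Unit): Unit = {
    var next = start
    while (next < end && !stopEarly) {
      val batchEnd = math.min(next + batchSize, end)
      var i = next
      while (i < batchEnd) {
        emit(i)
        if (stopEarly) return      // checked after every row
        i += 1
      }
      next = batchEnd
    }
  }

  // Per-batch check (the reviewer's suggestion): the inner loop stays tight and
  // the early stop only takes effect at the next batch boundary.
  def produceBatchGranularity(start: Long, end: Long, batchSize: Long)(emit: Long => Unit): Unit = {
    var next = start
    while (next < end && !stopEarly) {  // checked once per batch
      val batchEnd = math.min(next + batchSize, end)
      var i = next
      while (i < batchEnd) { emit(i); i += 1 }
      next = batchEnd
    }
  }
}

With the per-batch variant, per-row metric updates such as numOutput.add(1) can also stay at batch granularity, which is what the author agrees to in the thread above.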
@@ -84,9 +84,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
+ | if ($countTerm == $limit) {
+ |   $stopEarly = true;
+ | }
| ${consume(ctx, input)}
- | } else {
Contributor:

Do we need to remove this? Isn't it safer to leave it here?

Member Author:

I think we never execute this branch. If we do, there should be a bug.

- |   $stopEarly = true;
| }
""".stripMargin
}
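To summarize what the limit hunk changes: the stopEarly flag is what the stage's overridden stopEarly() method returns, so moving its assignment changes when the upstream loops see it. A self-contained sketch of the bookkeeping (assumed names, not the actual generated code):

// Sketch of the limit bookkeeping before and after this commit.
final class LimitSketch(limit: Int) {
  private var count = 0
  private var stopEarlyFlag = false

  // Upstream loops poll this, as in the HashAggregateExec and RangeExec hunks.
  def stopEarly(): Boolean = stopEarlyFlag

  def consume(row: Long)(downstream: Long => Unit): Unit = {
    if (count < limit) {
      count += 1
      // After this commit: flip the flag as soon as the limit-th row arrives,
      // so the stage stops without having to push one extra row.
      if (count == limit) {
        stopEarlyFlag = true
      }
      downstream(row)
    }
    // Before this commit an `else { stopEarlyFlag = true }` branch lived here;
    // it is removed because, once the flag flips eagerly, upstream operators
    // should never deliver a row past the limit (matching the author's reply).
  }
}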
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (29 additions, 2 deletions)
@@ -25,9 +25,8 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.util.StringUtils
- import org.apache.spark.sql.execution.aggregate
+ import org.apache.spark.sql.execution.{aggregate, FilterExec, RangeExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
- import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -2849,6 +2848,34 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val result = ds.flatMap(_.bar).distinct
result.rdd.isEmpty
}

test("SPARK-25497: limit operation within whole stage codegen should not " +
"consume all the inputs") {

val aggDF = spark.range(0, 100, 1, 1)
.groupBy("id")
.count().limit(1).filter('count > 0)
aggDF.collect()
val aggNumRecords = aggDF.queryExecution.sparkPlan.collect {
case h: HashAggregateExec => h
}.map { hashNode =>
hashNode.metrics("numOutputRows").value
}.sum
// The first hash aggregate node outputs 100 records.
// The second hash aggregate before local limit outputs 1 record.
assert(aggNumRecords == 101)

val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0)
.selectExpr("id + 1 as id2").limit(1).filter('id > 50)
filterDF.collect()
val filterNumRecords = filterDF.queryExecution.sparkPlan.collect {
case f @ FilterExec(_, r: RangeExec) => (f, r)
}.map { case (filterNode, rangeNode) =>
(filterNode.metrics("numOutputRows").value, rangeNode.metrics("numOutputRows").value)
}.head
// RangeNode and FilterNode both output 1 record.
assert(filterNumRecords == Tuple2(1, 1))
}
}

case class Foo(bar: Option[String])