diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d560..74c9c0599271 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. */ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ef16292a8e75..fac2ac914462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -307,4 +309,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { // a different query can result in codegen cache miss, that's by design } } + + test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } }