diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 9d1636ccf271..b41a4ff76667 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -454,6 +454,7 @@ case class RowToColumnarExec(child: SparkPlan) extends UnaryExecNode {
 
           override def next(): ColumnarBatch = {
             cb.setNumRows(0)
+            vectors.foreach(_.reset())
             var rowCount = 0
             while (rowCount < numRows && rowIterator.hasNext) {
               val row = rowIterator.next()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index b8df6f2bebf5..2a4c15233fe3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
 import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
 import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, Metadata, StructType}
 import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch, ColumnarMap, ColumnVector}
@@ -171,6 +172,30 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("reset column vectors") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config(COLUMN_BATCH_SIZE.key, 2)
+      .withExtensions { extensions =>
+        extensions.injectColumnar(session =>
+          MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
+      .getOrCreate()
+
+    try {
+      assert(session.sessionState.columnarRules.contains(
+        MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
+      import session.sqlContext.implicits._
+
+      val input = Seq((100L), (200L), (300L))
+      val data = input.toDF("vals").repartition(1)
+      val df = data.selectExpr("vals + 1")
+      val result = df.collect()
+      assert(result sameElements input.map(x => Row(x + 2)))
+    } finally {
+      stop(session)
+    }
+  }
+
   test("use custom class for extensions") {
     val session = SparkSession.builder()
       .master("local[1]")
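
A note on why the one-line fix matters: `RowToColumnarExec` allocates its writable column vectors once and reuses them for every batch it emits, but `cb.setNumRows(0)` only resets the batch's row count, not the vectors' internal write position, so each new batch was appended after the previous batch's data. Below is a minimal standalone sketch of that reuse pattern; it is illustrative, not the actual `RowToColumnarExec` code, and assumes a Spark build on the classpath (the `ResetSketch` object name is invented for the example):

```scala
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.LongType

object ResetSketch {
  def main(args: Array[String]): Unit = {
    val vec = new OnHeapColumnVector(2, LongType)

    // "Batch 1": append two values; the vector's internal write index is now 2.
    vec.appendLong(100L)
    vec.appendLong(200L)
    println(vec.getLong(0)) // 100

    // Reusing the vector for "batch 2" without reset() would keep appending
    // past index 1, so reads at positions 0..1 would still see batch-1 data.
    // reset() rewinds the write index (and clears nulls/child data), making
    // the vector safe to reuse for the next batch.
    vec.reset()
    vec.appendLong(300L)
    println(vec.getLong(0)) // 300, not 100

    vec.close()
  }
}
```

The regression test exercises exactly this path: pinning `COLUMN_BATCH_SIZE` to 2 with three rows in a single partition forces `RowToColumnarExec` to emit two batches from the same vectors, so without the `reset()` the second batch would be built on top of the first batch's leftover data.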