Skip to content

Commit 093728e

Browse files
committed
Call valueVector.clear() in ArrayWriter.reset() after each batch.
1 parent 75cf369 commit 093728e

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

python/pyspark/sql/tests.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4680,6 +4680,26 @@ def test_supported_types(self):
46804680
self.assertPandasEqual(expected2, result2)
46814681
self.assertPandasEqual(expected3, result3)
46824682

4683+
def test_array_type_correct(self):
4684+
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
4685+
4686+
df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
4687+
4688+
output_schema = StructType(
4689+
[StructField('id', LongType()),
4690+
StructField('v', IntegerType()),
4691+
StructField('arr', ArrayType(LongType()))])
4692+
4693+
udf = pandas_udf(
4694+
lambda pdf: pdf,
4695+
output_schema,
4696+
PandasUDFType.GROUPED_MAP
4697+
)
4698+
4699+
result = df.groupby('id').apply(udf).sort('id').toPandas()
4700+
expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
4701+
self.assertPandasEqual(expected, result)
4702+
46834703
def test_register_grouped_map_udf(self):
46844704
from pyspark.sql.functions import pandas_udf, PandasUDFType
46854705

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ private[arrow] class ArrayWriter(
311311
override def reset(): Unit = {
312312
super.reset()
313313
elementWriter.reset()
314+
valueVector.clear()
314315
}
315316
}
316317

0 commit comments

Comments
 (0)