diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index ac6b6f68aedac..5287826c1b4ec 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -22,6 +22,7 @@
 import tempfile
 import time
 import unittest
+import uuid
 from typing import cast
 
 from pyspark.sql import SparkSession, Row
@@ -1176,6 +1177,41 @@ def test_df_show(self):
         with self.assertRaisesRegex(TypeError, "Parameter 'truncate=foo'"):
             df.show(truncate="foo")
 
+    def test_df_is_empty(self):
+        # SPARK-39084: Fix df.rdd.isEmpty() resulting in JVM crash.
+
+        # This particular example of DataFrame reproduces an issue in isEmpty call
+        # which could result in JVM crash.
+        data = []
+        for t in range(0, 10000):
+            id = str(uuid.uuid4())
+            if t == 0:
+                for i in range(0, 99):
+                    data.append((id,))
+            elif t < 10:
+                for i in range(0, 75):
+                    data.append((id,))
+            elif t < 100:
+                for i in range(0, 50):
+                    data.append((id,))
+            elif t < 1000:
+                for i in range(0, 25):
+                    data.append((id,))
+            else:
+                for i in range(0, 10):
+                    data.append((id,))
+
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        try:
+            df = self.spark.createDataFrame(data, ["col"])
+            df.coalesce(1).write.parquet(tmpPath)
+
+            res = self.spark.read.parquet(tmpPath).groupBy("col").count()
+            self.assertFalse(res.rdd.isEmpty())
+        finally:
+            shutil.rmtree(tmpPath)
+
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 6664acf957263..8d2f788e05cc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
+import org.apache.spark.{ContextAwareIterator, TaskContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,7 +302,7 @@ object EvaluatePython {
   def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
     rdd.mapPartitions { iter =>
       registerPicklers() // let it called in executor
-      new SerDeUtil.AutoBatchedPickler(iter)
+      new SerDeUtil.AutoBatchedPickler(new ContextAwareIterator(TaskContext.get, iter))
     }
   }
 }
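Editorial note (sketch, not part of the patch): the Scala hunk wraps the partition iterator in ContextAwareIterator before handing it to SerDeUtil.AutoBatchedPickler. The idea is that once the task has completed or been killed (as happens with the early termination behind df.rdd.isEmpty()), the wrapper stops pulling from the upstream iterator, so the pickler cannot keep reading data whose backing resources the task-completion hooks may already have released. The sketch below illustrates that wrapping pattern under those assumptions; the class name TaskBoundIterator is hypothetical and is not Spark's actual ContextAwareIterator implementation.

    import org.apache.spark.TaskContext

    // Hypothetical illustration of the wrapping pattern used above; Spark's real
    // org.apache.spark.ContextAwareIterator plays this role in the patch.
    class TaskBoundIterator[T](context: TaskContext, delegate: Iterator[T]) extends Iterator[T] {
      // Stop yielding elements once the task is completed or interrupted, so a
      // consumer that keeps draining the iterator after early termination does
      // not touch resources freed by task-completion callbacks.
      override def hasNext: Boolean =
        !context.isCompleted() && !context.isInterrupted() && delegate.hasNext

      override def next(): T = delegate.next()
    }

In EvaluatePython.javaToPython this wrapping happens inside rdd.mapPartitions, with TaskContext.get supplying the current task's context, as shown in the diff.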