4 changes: 2 additions & 2 deletions python/pyspark/sql/tests.py
@@ -2923,8 +2923,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
     def test_unsupported_datatype(self):
-        schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
-        df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+        schema = StructType([StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: df.toPandas())
 
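The test now uses DateType as the representative Arrow-unsupported type, replacing ArrayType as the previous choice. For context, a minimal standalone sketch of the code path this test exercises (the app name and the Arrow config key are assumptions based on the Spark 2.x API, not part of this diff):

import datetime

from pyspark.sql import SparkSession
from pyspark.sql.types import DateType, StructField, StructType

spark = SparkSession.builder.appName("arrow-unsupported-type").getOrCreate()
# Opt in to the Arrow-accelerated toPandas() conversion.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("dt", DateType(), True)])
df = spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)

# DateType is not handled by the Arrow converters here, so toPandas()
# is expected to raise instead of returning a pandas DataFrame.
df.toPandas()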
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3086,7 +3087,8 @@ class Dataset[T] private[sql](
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     queryExecution.toRdd.mapPartitionsInternal { iter =>
-      ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
+      val context = TaskContext.get()
+      ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
     }
   }
 }
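Threading the TaskContext into toPayloadIterator lets the converter tie Arrow's off-heap allocations to the task lifecycle rather than relying on the iterator being fully consumed. A hypothetical sketch of that pattern, assuming a task-scoped RootAllocator (the helper name and listener body are illustrative, not code from this PR):

import org.apache.arrow.memory.RootAllocator
import org.apache.spark.TaskContext

object ArrowCleanupSketch {
  // Free the allocator's off-heap buffers when the task finishes, even if
  // a downstream operator (e.g. a limit) stops consuming the iterator early.
  def taskScopedAllocator(context: TaskContext): RootAllocator = {
    val allocator = new RootAllocator(Long.MaxValue)
    context.addTaskCompletionListener { _ =>
      allocator.close()
    }
    allocator
  }
}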