createDataFrame uses ArrowStreamPandasSerializer
BryanCutler committed Mar 14, 2019
commit 93bb83151ad84e72ebcd13ed236a8d6ce4f69c95
25 changes: 13 additions & 12 deletions python/pyspark/serializers.py
@@ -222,6 +222,17 @@ class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
"""
def __init__(self, send_start_stream=True):
self._send_start_stream = send_start_stream

def _init_dump_stream(self, stream):
"""
Called just before writing an Arrow stream
"""
# NOTE: this is required by Pandas UDFs to be called after creating first record batch so
# that any errors can be sent back to the JVM, but not interfere with the Arrow stream
if self._send_start_stream:
write_int(SpecialLengths.START_ARROW_STREAM, stream)

def dump_stream(self, iterator, stream):
import pyarrow as pa
@@ -242,10 +253,6 @@ def load_stream(self, stream):
for batch in reader:
yield batch

def _init_dump_stream(self, stream):
"""Called just before writing an Arrow stream"""
pass

def __repr__(self):
return "ArrowStreamSerializer"

@@ -338,8 +345,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
Serializes Pandas.Series as Arrow data with Arrow streaming format.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
def __init__(self, timezone, safecheck, assign_cols_by_name, send_start_stream=True):
super(ArrowStreamPandasSerializer, self).__init__(send_start_stream)
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
@@ -352,12 +359,6 @@ def arrow_to_pandas(self, arrow_column):
s = _check_series_localize_timestamps(s, self._timezone)
return s

def _init_dump_stream(self, stream):
# Override to signal the start of writing an Arrow stream
# NOTE: this is required by Pandas UDFs to be called after creating first record batch so
# that any errors can be sent back to the JVM, but not interfere with the Arrow stream
write_int(SpecialLengths.START_ARROW_STREAM, stream)

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
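The hunks above move the deferred stream-start signal out of ArrowStreamPandasSerializer and into the base ArrowStreamSerializer behind a send_start_stream flag. The sketch below is a minimal, standalone illustration of that pattern, not the pyspark implementation: the start marker is written only after the first record batch has been built, so an error raised while converting the first batch can still be reported as an ordinary framed message instead of corrupting a half-written Arrow stream. The numeric START_ARROW_STREAM value and the simplified write_int/dump_stream helpers here are assumptions for illustration only.

```python
# Minimal sketch of the deferred stream-start pattern (not the pyspark code).
import io
import struct

import pyarrow as pa

START_ARROW_STREAM = -6  # assumed stand-in for SpecialLengths.START_ARROW_STREAM


def write_int(value, stream):
    # Frame a 32-bit big-endian int, similar to pyspark's write_int helper.
    stream.write(struct.pack("!i", value))


def dump_stream(batch_iterator, stream, send_start_stream=True):
    writer = None
    try:
        for batch in batch_iterator:
            if writer is None:
                # The marker goes out only once the first batch exists, so a
                # conversion error never lands mid-Arrow-stream.
                if send_start_stream:
                    write_int(START_ARROW_STREAM, stream)
                writer = pa.RecordBatchStreamWriter(stream, batch.schema)
            writer.write_batch(batch)
    finally:
        if writer is not None:
            writer.close()


# Usage: serialize two small batches into an in-memory buffer with the marker
# suppressed, as the createDataFrame path does (send_start_stream=False).
batches = pa.table({"a": [1, 2, 3, 4]}).to_batches(max_chunksize=2)
buf = io.BytesIO()
dump_stream(iter(batches), buf, send_start_stream=False)
print(pa.ipc.open_stream(buf.getvalue()).read_all())
```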
37 changes: 19 additions & 18 deletions python/pyspark/sql/session.py
@@ -530,15 +530,24 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowStreamSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pyspark.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()

from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
import pyarrow as pa

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
Review comment from BryanCutler (author) on the line above: this is only since pyarrow 0.12.0; I can check into a workaround, although it might be a good time to bump the minimum pyarrow version.
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
schema = struct

# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
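The new list/tuple branch in the hunk above infers field types from the pandas data via pa.Schema.from_pandas (which, per the review comment, requires pyarrow 0.12.0 or later) and rebuilds a Spark StructType with the user-supplied names. Below is a rough, standalone illustration of that name/type pairing; the column and user names are hypothetical, and the Spark-side from_arrow_type mapping is left out to keep it self-contained.

```python
# Rough illustration of the name/type pairing done in the branch above: the
# user supplies only column names, while types and nullability come from the
# Arrow schema inferred from the pandas data.
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", None]})
names = ["id", "label"]  # hypothetical user-supplied names

arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
for name, field in zip(names, arrow_schema):
    # In session.py this becomes struct.add(name, from_arrow_type(field.type), ...)
    print(name, field.type, field.nullable)
```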
@@ -555,32 +564,24 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

# Create Arrow record batches
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
timezone, safecheck, col_by_name)
for pdf_slice in pdf_slices]

# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
if isinstance(schema, (list, tuple)):
struct = from_arrow_schema(batches[0].schema)
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
# Create list of Arrow (columns, type) for serializer dump_stream
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]

jsqlContext = self._wrapped._jsqlContext

safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name, send_start_stream=False)

def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
create_RDD_server)
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
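For context, a hedged end-to-end sketch of the code path changed above: with Arrow conversion enabled, createDataFrame on a pandas DataFrame goes through _create_from_pandas_with_arrow, which after this commit hands the sliced pandas columns to ArrowStreamPandasSerializer(..., send_start_stream=False) instead of pre-building record batches with _create_batch. The configuration key below is the Spark 2.4/3.0-era name and may differ in later releases; the master and app name are arbitrary.

```python
# Hedged usage sketch: exercising the Arrow-backed createDataFrame path.
# Assumes a local Spark with pandas and pyarrow installed.
import pandas as pd
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .master("local[2]")
         .appName("arrow-createDataFrame")
         # Enables the _create_from_pandas_with_arrow path patched above.
         .config("spark.sql.execution.arrow.enabled", "true")
         .getOrCreate())

pdf = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# Passing a list of names takes the new schema-from-Arrow branch; passing a
# full StructType (or no schema) takes the existing branches.
df = spark.createDataFrame(pdf, ["id", "value"])
df.show()
spark.stop()
```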
@@ -31,7 +31,6 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._