From 02f03b75b583ee9bc7b8d59a9995a1985825422d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 18 Feb 2019 11:48:10 +0800 Subject: [PATCH 1/4] [SPARK-26887][SQL][PYTHON][NS] Create datetime.date directly instead of creating datetime64 as intermediate data. ## What changes were proposed in this pull request? Currently `DataFrame.toPandas()` with arrow enabled or `ArrowStreamPandasSerializer` for pandas UDF with pyarrow<0.12 creates `datetime64[ns]` type series as intermediate data and then convert to `datetime.date` series, but the intermediate `datetime64[ns]` might cause an overflow even if the date is valid. ``` >>> import datetime >>> >>> t = [datetime.date(2262, 4, 12), datetime.date(2263, 4, 12)] >>> >>> df = spark.createDataFrame(t, 'date') >>> df.show() +----------+ | value| +----------+ |2262-04-12| |2263-04-12| +----------+ >>> >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> >>> df.toPandas() value 0 1677-09-21 1 1678-09-21 ``` We should avoid creating such intermediate data and create `datetime.date` series directly instead. ## How was this patch tested? Modified some tests to include the date which overflow caused by the intermediate conversion. Run tests with pyarrow 0.8, 0.10, 0.11, 0.12 in my local environment. Closes #23795 from ueshin/issues/SPARK-26887/date_as_object. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- python/pyspark/serializers.py | 5 +- python/pyspark/sql/dataframe.py | 5 +- python/pyspark/sql/tests/test_arrow.py | 5 +- .../sql/tests/test_pandas_udf_scalar.py | 3 +- python/pyspark/sql/types.py | 54 ++++++++++++------- 5 files changed, 44 insertions(+), 28 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 1d170530d285..31984a2c7a1c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -312,10 +312,9 @@ def __init__(self, timezone, safecheck): def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ - _check_series_convert_date, _check_series_localize_timestamps + _arrow_column_to_pandas, _check_series_localize_timestamps - s = arrow_column.to_pandas() - s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type)) s = _check_series_localize_timestamps(s, self._timezone) return s diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a1056d0b787e..472d2969b3e1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2107,14 +2107,13 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. 
if use_arrow: try: - from pyspark.sql.types import _check_dataframe_convert_date, \ + from pyspark.sql.types import _arrow_table_to_pandas, \ _check_dataframe_localize_timestamps import pyarrow batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) + pdf = _arrow_table_to_pandas(table, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 3ce6764278ce..d82da5cec983 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -68,7 +68,9 @@ def setUpClass(cls): (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"), + date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))] # TODO: remove version check once minimum pyarrow version is 0.10.0 if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): @@ -76,6 +78,7 @@ def setUpClass(cls): cls.data[0] = cls.data[0] + (bytearray(b"a"),) cls.data[1] = cls.data[1] + (bytearray(b"bb"),) cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + cls.data[3] = cls.data[3] + (bytearray(b"dddd"),) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index f29ff11ab998..71bae9689a5f 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self): data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), (2, None,), - (3, date(2100, 4, 4),)] + (3, date(2100, 4, 4),), + (4, date(2262, 4, 12),)] df = self.spark.createDataFrame(data, schema=schema) date_copy = pandas_udf(lambda t: t, returnType=DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f540954bcdb5..ad8a3e0b07a8 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1684,38 +1684,52 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_series_convert_date(series, data_type): - """ - Cast the series to datetime.date if it's a date type, otherwise returns the original series. +def _arrow_column_to_pandas(column, data_type): + """ Convert Arrow Column to pandas Series. - :param series: pandas.Series - :param data_type: a Spark data type for the series + :param series: pyarrow.lib.Column + :param data_type: a Spark data type for the column """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType: - return series.dt.date + # If the given column is a date type column, creates a series of datetime.date directly instead + # of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns] + # type handling. 
+ if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if type(data_type) == DateType: + return pd.Series(column.to_pylist(), name=column.name) + else: + return column.to_pandas() else: - return series + # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of + # np.datetime64. + return column.to_pandas(date_as_object=True) -def _check_dataframe_convert_date(pdf, schema): - """ Correct date type value to use datetime.date. +def _arrow_table_to_pandas(table, schema): + """ Convert Arrow Table to pandas DataFrame. Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should use datetime.date to match the behavior with when Arrow optimization is disabled. - :param pdf: pandas.DataFrame - :param schema: a Spark schema of the pandas.DataFrame + :param table: pyarrow.lib.Table + :param schema: a Spark schema of the pyarrow.lib.Table """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"): - for field in schema: - pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) - return pdf + # If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11 + # or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as + # intermediate data. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if any(type(field.dataType) == DateType for field in schema): + return pd.concat([_arrow_column_to_pandas(column, field.dataType) + for column, field in zip(table.itercolumns(), schema)], axis=1) + else: + return table.to_pandas() + else: + return table.to_pandas(date_as_object=True) def _get_local_timezone(): From e8193ed14755a4cd5ab8db6b81e46de8701784fb Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 7 Mar 2019 08:52:24 -0800 Subject: [PATCH 2/4] [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF This change adds support for returning StructType from a scalar Pandas UDF, where the return value of the function is a pandas.DataFrame. Nested structs are not supported and an error will be raised, child types can be any other type currently supported. Added additional unit tests to `test_pandas_udf_scalar` Closes #23900 from BryanCutler/pyspark-support-scalar_udf-StructType-SPARK-23836. Authored-by: Bryan Cutler Signed-off-by: Bryan Cutler --- python/pyspark/serializers.py | 39 ++++++++- python/pyspark/sql/functions.py | 12 ++- python/pyspark/sql/session.py | 3 +- .../sql/tests/test_pandas_udf_grouped_map.py | 1 + .../sql/tests/test_pandas_udf_scalar.py | 81 ++++++++++++++++++- python/pyspark/sql/types.py | 8 +- python/pyspark/sql/udf.py | 5 +- python/pyspark/worker.py | 12 ++- 8 files changed, 149 insertions(+), 12 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 31984a2c7a1c..6c33e0eee1c4 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -66,6 +66,7 @@ else: import pickle protocol = 3 + basestring = unicode = str xrange = range from pyspark import cloudpickle @@ -245,7 +246,7 @@ def __repr__(self): return "ArrowStreamSerializer" -def _create_batch(series, timezone, safecheck): +def _create_batch(series, timezone, safecheck, assign_cols_by_name): """ Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. 
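For reference, a minimal standalone sketch (not part of the patch) of what the struct branch added to `_create_batch` below does with a `StructType` return value: each column of the returned `pandas.DataFrame` becomes one Arrow array, and the arrays are combined into a single struct column of the record batch. It assumes a pyarrow version whose `StructArray.from_arrays` takes `(arrays, names)`; the patch itself also handles the pre-0.9.0 argument order.

```python
import pandas as pd
import pyarrow as pa

# A StructType return value arrives as a pandas.DataFrame.
pdf = pd.DataFrame({'first': ['John', 'Jane'], 'last': ['Doe', 'Doe']})

# One Arrow array per DataFrame column, then a single StructArray combining them.
arrays = [pa.Array.from_pandas(pdf[name]) for name in ['first', 'last']]
struct_arr = pa.StructArray.from_arrays(arrays, ['first', 'last'])
batch = pa.RecordBatch.from_arrays([struct_arr], ['_0'])

print(batch.schema)  # _0: struct<first: string, last: string>
```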
@@ -255,6 +256,7 @@ def _create_batch(series, timezone, safecheck): """ import decimal from distutils.version import LooseVersion + import pandas as pd import pyarrow as pa from pyspark.sql.types import _check_series_convert_timestamps_internal # Make input conform to [(series1, type1), (series2, type2), ...] @@ -296,7 +298,34 @@ def create_array(s, t): raise RuntimeError(error_msg % (s.dtype, t), e) return array - arrs = [create_array(s, t) for s, t in series] + arrs = [] + for s, t in series: + if t is not None and pa.types.is_struct(t): + if not isinstance(s, pd.DataFrame): + raise ValueError("A field of type StructType expects a pandas.DataFrame, " + "but got: %s" % str(type(s))) + + # Input partition and result pandas.DataFrame empty, make empty Arrays with struct + if len(s) == 0 and len(s.columns) == 0: + arrs_names = [(pa.array([], type=field.type), field.name) for field in t] + # Assign result columns by schema name if user labeled with strings + elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns): + arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t] + # Assign result columns by position + else: + arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) + for i, field in enumerate(t)] + + struct_arrs, struct_names = zip(*arrs_names) + + # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version + if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): + arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) + else: + arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) + else: + arrs.append(create_array(s, t)) + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -305,10 +334,11 @@ class ArrowStreamPandasSerializer(Serializer): Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def __init__(self, timezone, safecheck): + def __init__(self, timezone, safecheck, assign_cols_by_name): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ @@ -327,7 +357,8 @@ def dump_stream(self, iterator, stream): writer = None try: for series in iterator: - batch = _create_batch(series, self._timezone, self._safecheck) + batch = _create_batch(series, self._timezone, self._safecheck, + self._assign_cols_by_name) if writer is None: write_int(SpecialLengths.START_ARROW_STREAM, stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cac566c74cd9..584de7be33ca 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2872,8 +2872,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - :class:`MapType`, :class:`StructType` are currently not supported as output types. + :class:`MapType`, nested :class:`StructType` are currently not supported as output types. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. 
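As a usage-level illustration (not part of the patch; assumes an active `spark` session, a pyarrow version with struct support, and the default `assignColumnsByName` setting), the returned `pandas.DataFrame` columns are matched to the declared struct fields by name when they are labeled with strings, so their order does not have to follow the schema:

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

df = spark.range(3)

@pandas_udf("first string, last string")
def names(id):
    # Columns are labeled with the schema's field names, so they are matched by name
    # (not by position) when the DataFrame is converted to an Arrow struct column.
    return pd.DataFrame({'last': id.apply(lambda i: 'Doe'),
                         'first': id.apply(lambda i: 'John')})

df.select(names('id')).show()
```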
@@ -2898,6 +2899,15 @@ def pandas_udf(f=None, returnType=None, functionType=None): +----------+--------------+------------+ | 8| JOHN DOE| 22| +----------+--------------+------------+ + >>> @pandas_udf("first string, last string") # doctest: +SKIP + ... def split_expand(n): + ... return n.str.split(expand=True) + >>> df.select(split_expand("name")).show() # doctest: +SKIP + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input column, but is the length of an internal batch used for each call to the function. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index bdf1701a5895..32a2c8a67252 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): # Create Arrow record batches safecheck = self._wrapped._conf.arrowSafeTypeConversion() + col_by_name = True # col by name only applies to StructType columns, can't happen here batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)], - timezone, safecheck) + timezone, safecheck, col_by_name) for pdf_slice in pdf_slices] # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index a0a25359d1e0..f7684d3fbcff 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -273,6 +273,7 @@ def test_unsupported_types(self): StructField('map', MapType(StringType(), IntegerType())), StructField('arr_ts', ArrayType(TimestampType())), StructField('null', NullType()), + StructField('struct', StructType([StructField('l', LongType())])), ] # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 71bae9689a5f..18c19199838a 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -23,13 +23,16 @@ import time import unittest +if sys.version >= '3': + unicode = str + from datetime import date, datetime from decimal import Decimal from distutils.version import LooseVersion from pyspark.rdd import PythonEvalType from pyspark.sql import Column -from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf +from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf from pyspark.sql.types import Row from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException @@ -265,6 +268,64 @@ def test_vectorized_udf_null_array(self): result = df.select(array_f(col('array'))) self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_struct_type(self): + import pandas as pd + + df = self.spark.range(10) + return_type = StructType([ + StructField('id', LongType()), + StructField('str', StringType())]) + + def func(id): + return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + + f = pandas_udf(func, returnType=return_type) + + expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) + .alias('struct')).collect() + + actual = df.select(f(col('id')).alias('struct')).collect() + self.assertEqual(expected, actual) + + g = pandas_udf(func, 'id: long, 
str: string') + actual = df.select(g(col('id')).alias('struct')).collect() + self.assertEqual(expected, actual) + + def test_vectorized_udf_struct_complex(self): + import pandas as pd + + df = self.spark.range(10) + return_type = StructType([ + StructField('ts', TimestampType()), + StructField('arr', ArrayType(LongType()))]) + + @pandas_udf(returnType=return_type) + def f(id): + return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)), + 'arr': id.apply(lambda i: [i, i + 1])}) + + actual = df.withColumn('f', f(col('id'))).collect() + for i, row in enumerate(actual): + id, f = row + self.assertEqual(i, id) + self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0]) + self.assertListEqual([i, i + 1], f[1]) + + def test_vectorized_udf_nested_struct(self): + nested_type = StructType([ + StructField('id', IntegerType()), + StructField('nested', StructType([ + StructField('foo', StringType()), + StructField('bar', FloatType()) + ])) + ]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Invalid returnType with scalar Pandas UDFs'): + pandas_udf(lambda x: x, returnType=nested_type) + def test_vectorized_udf_complex(self): df = self.spark.range(10).select( col('id').cast('int').alias('a'), @@ -331,6 +392,20 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_struct_with_empty_partition(self): + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\ + .withColumn('name', lit('John Doe')) + + @pandas_udf("first string, last string") + def split_expand(n): + return n.str.split(expand=True) + + result = df.select(split_expand('name')).collect() + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('John', row[0]['first']) + self.assertEqual('Doe', row[0]['last']) + def test_vectorized_udf_varargs(self): df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda *v: v[0], LongType()) @@ -343,6 +418,10 @@ def test_vectorized_udf_unsupported_types(self): NotImplementedError, 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'): + pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())]))) def test_vectorized_udf_dates(self): schema = StructType().add("idx", LongType()).add("date", DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ad8a3e0b07a8..086765382fd5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1616,9 +1616,15 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: - if type(dt.elementType) == TimestampType: + if type(dt.elementType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) + elif type(dt) == StructType: + if any(type(field.dataType) == StructType for field in dt): + raise TypeError("Nested StructType not supported in conversion to Arrow") + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in dt] + arrow_type = pa.struct(fields) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type diff --git 
a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index bd137a1a0268..20db0522ccf5 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -124,7 +124,7 @@ def returnType(self): elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: if isinstance(self._returnType_placeholder, StructType): try: - to_arrow_schema(self._returnType_placeholder) + to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid returnType with grouped map Pandas UDFs: " @@ -134,6 +134,9 @@ def returnType(self): "UDFs: returnType must be a StructType.") elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: try: + # StructType is not yet allowed as a return type, explicitly check here to fail fast + if isinstance(self._returnType_placeholder, StructType): + raise TypeError to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 01934a0e7275..0e9b6d665a36 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -39,7 +39,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type +from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): + pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series" raise TypeError("Return type of the user-defined function should be " - "Pandas.Series, but is {}".format(type(result))) + "{}, but is {}".format(pd_type, type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) @@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type): timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true' - ser = ArrowStreamPandasSerializer(timezone, safecheck) + # NOTE: this is duplicated from wrap_grouped_map_pandas_udf + assign_cols_by_name = runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ + .lower() == "true" + + ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) else: ser = BatchedSerializer(PickleSerializer(), 100) From 8d69b8c93f70924b9c015441745971441fa9d941 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 21 Mar 2019 17:44:51 +0900 Subject: [PATCH 3/4] [SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality ## What changes were proposed in this pull request? This change is a cleanup and consolidation of 3 areas related to Pandas UDFs: 1) `ArrowStreamPandasSerializer` now inherits from `ArrowStreamSerializer` and uses the base class `dump_stream`, `load_stream` to create Arrow reader/writer and send Arrow record batches. `ArrowStreamPandasSerializer` makes the conversions to/from Pandas and converts to Arrow record batch iterators. This change removed duplicated creation of Arrow readers/writers. 
2) `createDataFrame` with Arrow now uses `ArrowStreamPandasSerializer` instead of doing its own conversions from Pandas to Arrow and sending record batches through `ArrowStreamSerializer`. 3) Grouped Map UDFs now reuse existing logic in `ArrowStreamPandasSerializer` to send Pandas DataFrame results as a `StructType` instead of separating each column from the DataFrame. This makes the code a little more consistent with the Python worker, but does require that the returned StructType column is flattened out in `FlatMapGroupsInPandasExec` in Scala. ## How was this patch tested? Existing tests and ran tests with pyarrow 0.12.0 Closes #24095 from BryanCutler/arrow-refactor-cleanup-UDFs. Authored-by: Bryan Cutler Signed-off-by: Hyukjin Kwon --- python/pyspark/serializers.py | 218 ++++++++++-------- python/pyspark/sql/session.py | 42 ++-- python/pyspark/worker.py | 23 +- .../sql/execution/arrow/ArrowConverters.scala | 1 - .../python/FlatMapGroupsInPandasExec.scala | 12 +- 5 files changed, 161 insertions(+), 135 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6c33e0eee1c4..9510e5e232ea 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -246,92 +246,13 @@ def __repr__(self): return "ArrowStreamSerializer" -def _create_batch(series, timezone, safecheck, assign_cols_by_name): +class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ - Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. + Serializes Pandas.Series as Arrow data with Arrow streaming format. - :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) :param timezone: A timezone to respect when handling timestamp values - :return: Arrow RecordBatch - """ - import decimal - from distutils.version import LooseVersion - import pandas as pd - import pyarrow as pa - from pyspark.sql.types import _check_series_convert_timestamps_internal - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - def create_array(s, t): - mask = s.isnull() - # Ensure timestamp series are in expected form for Spark internal representation - # TODO: maybe don't need None check anymore as of Arrow 0.9.1 - if t is not None and pa.types.is_timestamp(t): - s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) - # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 - return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) - elif t is not None and pa.types.is_string(t) and sys.version < '3': - # TODO: need decode before converting to Arrow in Python 2 - # TODO: don't need as of Arrow 0.9.1 - return pa.Array.from_pandas(s.apply( - lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) - elif t is not None and pa.types.is_decimal(t) and \ - LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. - return pa.Array.from_pandas(s.apply( - lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) - elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): - # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. 
- return pa.Array.from_pandas(s, mask=mask, type=t) - - try: - array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck) - except pa.ArrowException as e: - error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ - "Array (%s). It can be caused by overflows or other unsafe " + \ - "conversions warned by Arrow. Arrow safe type check can be " + \ - "disabled by using SQL config " + \ - "`spark.sql.execution.pandas.arrowSafeTypeConversion`." - raise RuntimeError(error_msg % (s.dtype, t), e) - return array - - arrs = [] - for s, t in series: - if t is not None and pa.types.is_struct(t): - if not isinstance(s, pd.DataFrame): - raise ValueError("A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s))) - - # Input partition and result pandas.DataFrame empty, make empty Arrays with struct - if len(s) == 0 and len(s.columns) == 0: - arrs_names = [(pa.array([], type=field.type), field.name) for field in t] - # Assign result columns by schema name if user labeled with strings - elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns): - arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t] - # Assign result columns by position - else: - arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) - for i, field in enumerate(t)] - - struct_arrs, struct_names = zip(*arrs_names) - - # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version - if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): - arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) - else: - arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) - else: - arrs.append(create_array(s, t)) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - - -class ArrowStreamPandasSerializer(Serializer): - """ - Serializes Pandas.Series as Arrow data with Arrow streaming format. + :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation + :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name """ def __init__(self, timezone, safecheck, assign_cols_by_name): @@ -348,39 +269,138 @@ def arrow_to_pandas(self, arrow_column): s = _check_series_localize_timestamps(s, self._timezone) return s + def _create_batch(self, series): + """ + Create an Arrow record batch from the given pandas.Series or list of Series, + with optional type. + + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch + """ + import decimal + from distutils.version import LooseVersion + import pandas as pd + import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal + # Make input conform to [(series1, type1), (series2, type2), ...] 
+ if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + def create_array(s, t): + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone) + # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 + return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. + return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) + elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. + return pa.Array.from_pandas(s, mask=mask, type=t) + + try: + array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) + except pa.ArrowException as e: + error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ + "Array (%s). It can be caused by overflows or other unsafe " + \ + "conversions warned by Arrow. Arrow safe type check can be " + \ + "disabled by using SQL config " + \ + "`spark.sql.execution.pandas.arrowSafeTypeConversion`." + raise RuntimeError(error_msg % (s.dtype, t), e) + return array + + arrs = [] + for s, t in series: + if t is not None and pa.types.is_struct(t): + if not isinstance(s, pd.DataFrame): + raise ValueError("A field of type StructType expects a pandas.DataFrame, " + "but got: %s" % str(type(s))) + + # Input partition and result pandas.DataFrame empty, make empty Arrays with struct + if len(s) == 0 and len(s.columns) == 0: + arrs_names = [(pa.array([], type=field.type), field.name) for field in t] + # Assign result columns by schema name if user labeled with strings + elif self._assign_cols_by_name and any(isinstance(name, basestring) + for name in s.columns): + arrs_names = [(create_array(s[field.name], field.type), field.name) + for field in t] + # Assign result columns by position + else: + arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) + for i, field in enumerate(t)] + + struct_arrs, struct_names = zip(*arrs_names) + + # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version + if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): + arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) + else: + arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) + else: + arrs.append(create_array(s, t)) + + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. 
""" - import pyarrow as pa - writer = None - try: - for series in iterator: - batch = _create_batch(series, self._timezone, self._safecheck, - self._assign_cols_by_name) - if writer is None: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - writer = pa.RecordBatchStreamWriter(stream, batch.schema) - writer.write_batch(batch) - finally: - if writer is not None: - writer.close() + batches = (self._create_batch(series) for series in iterator) + super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) import pyarrow as pa - reader = pa.ipc.open_stream(stream) - - for batch in reader: + for batch in batches: yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" +class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + """ + Serializer used by Python worker to evaluate Pandas UDFs + """ + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def init_stream_yield_batches(): + should_write_start_length = True + for series in iterator: + batch = self._create_batch(series) + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + + def __repr__(self): + return "ArrowStreamPandasUDFSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 32a2c8a67252..b11e0f3ff69d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -530,8 +530,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
""" - from pyspark.serializers import ArrowStreamSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from distutils.version import LooseVersion + from pyspark.serializers import ArrowStreamPandasSerializer + from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,6 +540,19 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): require_minimum_pyarrow_version() from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + import pyarrow as pa + + # Create the Spark schema from list of names passed in with Arrow types + if isinstance(schema, (list, tuple)): + if LooseVersion(pa.__version__) < LooseVersion("0.12.0"): + temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False) + arrow_schema = temp_batch.schema + else: + arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) + struct = StructType() + for name, field in zip(schema, arrow_schema): + struct.add(name, from_arrow_type(field.type), nullable=field.nullable) + schema = struct # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): @@ -555,23 +569,16 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) - # Create Arrow record batches - safecheck = self._wrapped._conf.arrowSafeTypeConversion() - col_by_name = True # col by name only applies to StructType columns, can't happen here - batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)], - timezone, safecheck, col_by_name) - for pdf_slice in pdf_slices] - - # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) - if isinstance(schema, (list, tuple)): - struct = from_arrow_schema(batches[0].schema) - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct + # Create list of Arrow (columns, type) for serializer dump_stream + arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] + for pdf_slice in pdf_slices] jsqlContext = self._wrapped._jsqlContext + safecheck = self._wrapped._conf.arrowSafeTypeConversion() + col_by_name = True # col by name only applies to StructType columns, can't happen here + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) + def reader_func(temp_filename): return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) @@ -579,8 +586,7 @@ def create_RDD_server(): return self._jvm.ArrowRDDServer(jsqlContext) # Create Spark DataFrame from Arrow stream file, using one batch per partition - jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func, - create_RDD_server) + jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0e9b6d665a36..f59fb443b4db 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,7 +38,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, 
UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -101,10 +101,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - assign_cols_by_name = runner_conf.get( - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true") - assign_cols_by_name = assign_cols_by_name.lower() == "true" +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -123,15 +120,9 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result - # Assign result columns by schema name if user labeled with strings, else use position - if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns): - return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] - else: - return [(result[result.columns[i]], to_arrow_type(field.dataType)) - for i, field in enumerate(return_type)] - - return wrapped + return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -225,7 +216,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -255,12 +246,12 @@ def read_udfs(pickleSer, infile, eval_type): timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true' - # NOTE: this is duplicated from wrap_grouped_map_pandas_udf + # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType assign_cols_by_name = runner_conf.get( "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ .lower() == "true" - ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name) else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2bf6a58b5565..884dc8c6215f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -31,7 +31,6 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.network.util.JavaUtils -import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e9cff1a5a200..ce755ffb7c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] @@ -145,7 +146,16 @@ case class FlatMapGroupsInPandasExec( sessionLocalTimeZone, pythonRunnerConf).compute(grouped, context.partitionId(), context) - columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) } } } From adb3a017a7be3eaa422a9c0f6ee1c7ff1035915a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 25 Mar 2019 11:26:09 -0700 Subject: [PATCH 4/4] [SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF. ## What changes were proposed in this pull request? Now that we support returning pandas DataFrame for struct type in Scalar Pandas UDF. If we chain another Pandas UDF after the Scalar Pandas UDF returning pandas DataFrame, the argument of the chained UDF will be pandas DataFrame, but currently we don't support pandas DataFrame as an argument of Scalar Pandas UDF. That means there is an inconsistency between the chained UDF and the single UDF. We should support taking pandas DataFrame for struct type argument in Scalar Pandas UDF to be consistent. Currently pyarrow >=0.11 is supported. ## How was this patch tested? Modified and added some tests. Closes #24177 from ueshin/issues/SPARK-27240/structtype_argument. 
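An illustrative sketch of the resulting behavior (assumes an active `spark` session and a pyarrow version that passes the struct-support checks in this patch):

```python
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, struct

df = spark.range(3).withColumn('name', col('id').cast('string'))

@pandas_udf('id long, name string')
def passthrough(s):
    # With this change a struct argument is a pandas.DataFrame, one column per field,
    # matching what a chained struct-returning Pandas UDF already hands over.
    assert isinstance(s, pd.DataFrame)
    return s

df.select(passthrough(struct('id', 'name'))).show()
```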
Authored-by: Takuya UESHIN Signed-off-by: Bryan Cutler --- python/pyspark/serializers.py | 29 +++++++++++++--- .../sql/tests/test_pandas_udf_scalar.py | 33 +++++++++++++++++++ python/pyspark/sql/types.py | 10 ++++++ python/pyspark/worker.py | 6 +++- 4 files changed, 72 insertions(+), 6 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9510e5e232ea..dfe1aa14798f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -261,11 +261,10 @@ def __init__(self, timezone, safecheck, assign_cols_by_name): self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name - def arrow_to_pandas(self, arrow_column): - from pyspark.sql.types import from_arrow_type, \ - _arrow_column_to_pandas, _check_series_localize_timestamps + def arrow_to_pandas(self, arrow_column, data_type): + from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps - s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, data_type) s = _check_series_localize_timestamps(s, self._timezone) return s @@ -367,8 +366,10 @@ def load_stream(self, stream): """ batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) import pyarrow as pa + from pyspark.sql.types import from_arrow_type for batch in batches: - yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] + yield [self.arrow_to_pandas(c, from_arrow_type(c.type)) + for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" @@ -379,6 +380,24 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): Serializer used by Python worker to evaluate Pandas UDFs """ + def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False): + super(ArrowStreamPandasUDFSerializer, self) \ + .__init__(timezone, safecheck, assign_cols_by_name) + self._df_for_struct = df_for_struct + + def arrow_to_pandas(self, arrow_column, data_type): + from pyspark.sql.types import StructType, \ + _arrow_column_to_pandas, _check_dataframe_localize_timestamps + + if self._df_for_struct and type(data_type) == StructType: + import pandas as pd + series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name) + for column, field in zip(arrow_column.flatten(), data_type)] + s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone) + else: + s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type) + return s + def dump_stream(self, iterator, stream): """ Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. 
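A minimal pandas/pyarrow-only sketch (not part of the patch) of what the `arrow_to_pandas` override above does with a struct column when `df_for_struct` is enabled: each flattened child becomes a named Series, and the Series are concatenated column-wise into a `pandas.DataFrame`.

```python
import pandas as pd
import pyarrow as pa

struct_arr = pa.StructArray.from_arrays(
    [pa.array([1, 2]), pa.array(['a', 'b'])], ['id', 'str'])

# flatten() yields one Arrow array per struct field; each becomes a named pandas
# Series, and the Series are concatenated into a DataFrame.
series = [pd.Series(child.to_pylist(), name=name)
          for child, name in zip(struct_arr.flatten(), ['id', 'str'])]
pdf = pd.concat(series, axis=1)

print(pdf.columns.tolist())  # ['id', 'str']
```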
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 18c19199838a..5efcfd343013 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -270,6 +270,7 @@ def test_vectorized_udf_null_array(self): def test_vectorized_udf_struct_type(self): import pandas as pd + import pyarrow as pa df = self.spark.range(10) return_type = StructType([ @@ -291,6 +292,18 @@ def func(id): actual = df.select(g(col('id')).alias('struct')).collect() self.assertEqual(expected, actual) + struct_f = pandas_udf(lambda x: x, return_type) + actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str')))) + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + from py4j.protocol import Py4JJavaError + with self.assertRaisesRegexp( + Py4JJavaError, + 'Unsupported type in conversion from Arrow'): + self.assertEqual(expected, actual.collect()) + else: + self.assertEqual(expected, actual.collect()) + def test_vectorized_udf_struct_complex(self): import pandas as pd @@ -363,6 +376,26 @@ def test_vectorized_udf_chained(self): res = df.select(g(f(col('id')))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_chained_struct_type(self): + import pandas as pd + + df = self.spark.range(10) + return_type = StructType([ + StructField('id', LongType()), + StructField('str', StringType())]) + + @pandas_udf(return_type) + def f(id): + return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + + g = pandas_udf(lambda x: x, return_type) + + expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) + .alias('struct')).collect() + + actual = df.select(g(f(col('id'))).alias('struct')).collect() + self.assertEqual(expected, actual) + def test_vectorized_udf_wrong_return_type(self): with QuietTest(self.sc): with self.assertRaisesRegexp( diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 086765382fd5..3246e9ee31f5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1677,6 +1677,16 @@ def from_arrow_type(at): if types.is_timestamp(at.value_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) + elif types.is_struct(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for StructType support.") + if any(types.is_struct(field.type) for field in at): + raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in at]) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f59fb443b4db..478fdc081d35 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -251,7 +251,11 @@ def read_udfs(pickleSer, infile, eval_type): "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ .lower() == "true" - ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name) + # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of + # pandas Series. See SPARK-27240. 
+ df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100)
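Taken together, the struct-return support from patch 2 and the struct-argument support from patch 4 allow scalar Pandas UDFs to be chained through a struct column, as exercised by `test_vectorized_udf_chained_struct_type` above. A hedged end-to-end sketch (assumes an active `spark` session and a sufficiently new pyarrow):

```python
import pandas as pd
from pyspark.sql.functions import col, pandas_udf

# Returns a StructType as a pandas.DataFrame (patch 2).
make_struct = pandas_udf(
    lambda id: pd.DataFrame({'id': id, 'str': id.astype(str)}), 'id long, str string')

# Receives the struct argument as a pandas.DataFrame (patch 4), so chaining is
# consistent with calling either UDF on its own.
echo = pandas_udf(lambda x: x, 'id long, str string')

df = spark.range(5)
df.select(echo(make_struct(col('id'))).alias('struct')).show()
```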