[SPARK-23380][PYTHON] Make toPandas fallback to non-Arrow optimization if possible #20567
python/pyspark/sql/dataframe.py

@@ -1941,12 +1941,24 @@ def toPandas(self):
        timezone = None

        if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
            should_fall_back = False
            try:
                from pyspark.sql.types import _check_dataframe_convert_date, \
                    _check_dataframe_localize_timestamps
                from pyspark.sql.types import to_arrow_schema
                from pyspark.sql.utils import require_minimum_pyarrow_version
                import pyarrow
                require_minimum_pyarrow_version()
                # Check if its schema is convertible in Arrow format.
                to_arrow_schema(self.schema)
            except Exception as e:
[Review comment — Contributor] Do we want to catch more specific exceptions here? i.e. …

[Review comment — Author] Hm, it might depend on which message we want to show. Will open another PR as discussed above.
                # Fallback to convert to Pandas DataFrame without arrow if raise some exception
[Review comment — Member] Does this PR fall back to the original path if any exception occurs? E.g. …

[Review comment — Author] Yup, it falls back for an unsupported schema, a PyArrow version mismatch, and PyArrow missing. Will add a note to the PR description.
                should_fall_back = True

                warnings.warn(
                    "Arrow will not be used in toPandas: %s" % _exception_message(e))

            if not should_fall_back:
                import pyarrow
                from pyspark.sql.types import _check_dataframe_convert_date, \
                    _check_dataframe_localize_timestamps

                tables = self._collectAsArrow()
[Review comment — Member] Shouldn't this be in the …

[Review comment — Author] Please see #20567 (comment). @ueshin raised a similar concern.

[Review comment — Member] I see: we don't want to collect twice, so the schema conversion is run manually up front to decide whether to fall back. I think there still might be some cases where the Arrow path could fail, for example incompatible Arrow versions (say, a possible future version of pyarrow with Java still at 0.8), but this should cover the most common cases, so it seems fine to me.
                if tables:
                    table = pyarrow.concat_tables(tables)
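The pattern here is a cheap pre-flight check rather than a try/except around the whole Arrow path: verify the pyarrow requirement and the schema's Arrow-convertibility before collecting, warn and switch to the plain path if the check fails, and otherwise run the Arrow path unguarded so the data is never collected twice (see the thread above). A minimal, self-contained sketch of the warn-and-fall-back behaviour — an illustration, not Spark's implementation, and `rows_to_pandas` is a made-up helper:

```python
import warnings
import pandas as pd


def rows_to_pandas(rows, columns, arrow_enabled=True):
    """Toy version of the check-then-fallback flow added to toPandas."""
    if arrow_enabled:
        try:
            import pyarrow  # missing or too-old pyarrow is one of the fallback triggers
            data = {col: [row[i] for row in rows] for i, col in enumerate(columns)}
            return pyarrow.Table.from_pydict(data).to_pandas()
        except Exception as e:
            # Downgrade to the plain conversion with a warning instead of raising.
            warnings.warn("Arrow will not be used: %s" % e)
    return pd.DataFrame.from_records(rows, columns=columns)


print(rows_to_pandas([(1, "a"), (2, "b")], ["id", "label"]))
```

Unlike this toy, the real change runs its checks (`require_minimum_pyarrow_version()` and `to_arrow_schema(self.schema)`) before `_collectAsArrow()`, so a failure detected up front never costs an extra collection.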
@@ -1955,38 +1967,34 @@ def toPandas(self):
                    return _check_dataframe_localize_timestamps(pdf, timezone)
                else:
                    return pd.DataFrame.from_records([], columns=self.columns)
            except ImportError as e:
                msg = "note: pyarrow must be installed and available on calling Python process " \
                    "if using spark.sql.execution.arrow.enabled=true"
                raise ImportError("%s\n%s" % (_exception_message(e), msg))
        else:
            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
[Review comment — Author] Actual diff here is just …
        dtype = {}
        for field in self.schema:
            pandas_type = _to_corrected_pandas_type(field.dataType)
            # SPARK-21766: if an integer field is nullable and has null values, it can be
            # inferred by pandas as float column. Once we convert the column with NaN back
            # to integer type e.g., np.int16, we will hit exception. So we use the inferred
            # float type, not the corrected type from the schema in this case.
            if pandas_type is not None and \
                    not(isinstance(field.dataType, IntegralType) and field.nullable and
                        pdf[field.name].isnull().any()):
                dtype[field.name] = pandas_type

        for f, t in dtype.items():
            pdf[f] = pdf[f].astype(t, copy=False)
        if timezone is None:
            return pdf
        else:
            from pyspark.sql.types import _check_series_convert_timestamps_local_tz
            for field in self.schema:
                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                if isinstance(field.dataType, TimestampType):
                    pdf[field.name] = \
                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
            return pdf
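The SPARK-21766 caveat in the dtype-correction loop above is easy to reproduce with plain pandas: a nullable integer column that actually contains a null is inferred as float64 (NaN has no NumPy integer representation), and forcing it back to the schema's integer type fails. A small standalone illustration (the column name `age` is made up for the example):

```python
import numpy as np
import pandas as pd

# toPandas builds the frame with from_records, where SQL NULL arrives as None.
pdf = pd.DataFrame.from_records([(1,), (None,)], columns=["age"])
print(pdf["age"].dtype)  # float64, even though the Spark schema says integer

# Casting the NaN-bearing float column back to the corrected integer type fails
# (recent pandas raises ValueError), which is why such columns keep the inferred float type.
try:
    pdf["age"].astype(np.int16)
except (ValueError, TypeError) as e:
    print("cannot cast:", e)
```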
    def _collectAsArrow(self):
        """
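`_check_series_convert_timestamps_local_tz` is an internal PySpark helper; conceptually it reinterprets the collected, timezone-naive timestamps in the session timezone (`spark.sql.session.timeZone`). Roughly the kind of pandas operation involved — a hedged sketch, not the helper's actual code, and the "UTC" source timezone is an assumption made only for this example:

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["2018-02-10 12:00:00", "2018-02-10 13:30:00"]))

source_tz = "UTC"                    # assumption for the example
session_tz = "America/Los_Angeles"   # stand-in for spark.sql.session.timeZone

# Attach the source timezone, convert to the session timezone, then drop the tz info
# again so the resulting Series stays timezone-naive like the rest of the frame.
converted = s.dt.tz_localize(source_tz).dt.tz_convert(session_tz).dt.tz_localize(None)
print(converted)
```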
python/pyspark/sql/tests.py
@@ -32,6 +32,7 @@
import datetime
import array
import ctypes
import warnings
import py4j

try:
@@ -48,12 +49,12 @@
else:
    import unittest

from pyspark.util import _exception_message
[Review comment — Member] nit: add an empty line between this import and …
_pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    from pyspark.util import _exception_message
    # If Pandas version requirement is not satisfied, skip related tests.
    _pandas_requirement_message = _exception_message(e)
@@ -62,7 +63,6 @@
    from pyspark.sql.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    from pyspark.util import _exception_message
    # If Arrow version requirement is not satisfied, skip related tests.
    _pyarrow_requirement_message = _exception_message(e)
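These `*_requirement_message` values exist so the pandas- and Arrow-dependent tests can be skipped, with the import error as the reason, when the optional dependency is missing or too old. A hedged sketch of how such a message is typically consumed (the class and test names below are illustrative, not copied from tests.py):

```python
import unittest

# Computed once at import time: None means the requirement is satisfied.
_pandas_requirement_message = None
try:
    import pandas  # stand-in for require_minimum_pandas_version()
except ImportError as e:
    _pandas_requirement_message = str(e)


@unittest.skipIf(_pandas_requirement_message is not None, _pandas_requirement_message)
class PandasDependentTests(unittest.TestCase):
    def test_roundtrip(self):
        import pandas as pd
        self.assertEqual(len(pd.DataFrame({"a": [1, 2]})), 2)


if __name__ == "__main__":
    unittest.main()
```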
@@ -3437,12 +3437,22 @@ def create_pandas_data_frame(self):
        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
        return pd.DataFrame(data=data_dict)

    def test_unsupported_datatype(self):
    def test_toPandas_fallback(self):
        import pandas as pd

        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
        df = self.spark.createDataFrame([(None,)], schema=schema)
        df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
                df.toPandas()
            with warnings.catch_warnings(record=True) as warns:
                pdf = df.toPandas()
                # Catch and check the last UserWarning.
                user_warns = [
                    warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                self.assertTrue(len(user_warns) > 0)
                self.assertTrue(
                    "Arrow will not be used in toPandas" in _exception_message(user_warns[-1]))

        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

    def test_null_conversion(self):
        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
[Review comment] Here is the main change.
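For reference, this is roughly what the new behaviour looks like from a user session once the change is in: a MapType column is not Arrow-convertible, so `toPandas()` emits the UserWarning and still returns a correct pandas DataFrame instead of raising "Unsupported data type". A hedged end-to-end sketch, assuming a local SparkSession and a PySpark build that includes this patch:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, MapType, StringType, IntegerType

spark = SparkSession.builder.master("local[1]").appName("arrow-fallback-demo").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = spark.createDataFrame([({"a": 1},)], schema=schema)

# Expected: a UserWarning along the lines of
#   "Arrow will not be used in toPandas: ..."
# followed by the non-Arrow result.
pdf = df.toPandas()
print(pdf)

spark.stop()
```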