74 changes: 41 additions & 33 deletions python/pyspark/sql/dataframe.py
@@ -1941,12 +1941,24 @@ def toPandas(self):
timezone = None

if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
should_fall_back = False
Member Author

Here is the main change.

try:
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
from pyspark.sql.types import to_arrow_schema
from pyspark.sql.utils import require_minimum_pyarrow_version
import pyarrow
require_minimum_pyarrow_version()
# Check if its schema is convertible in Arrow format.
to_arrow_schema(self.schema)
except Exception as e:
Contributor

Do we want to catch more specific exceptions here? i.e. TypeError and ImportError?

Member Author

Hm, it might depend on which message we want to show. Will open another PR as discussed above.
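
For illustration, a minimal sketch of the narrower catch discussed in this thread (my own, not part of this commit; it assumes unsupported types surface as TypeError from to_arrow_schema, while require_minimum_pyarrow_version raises ImportError, which the tests below rely on):

import warnings

from pyspark.util import _exception_message


def arrow_precheck_narrow(df):
    """Hypothetical variant that treats only the expected failures as a fallback signal."""
    try:
        from pyspark.sql.types import to_arrow_schema
        from pyspark.sql.utils import require_minimum_pyarrow_version
        require_minimum_pyarrow_version()  # raises ImportError if pyarrow is missing or too old
        to_arrow_schema(df.schema)  # assumption: raises TypeError for unsupported types
        return False  # Arrow can be used
    except (ImportError, TypeError) as e:
        warnings.warn("Arrow will not be used in toPandas: %s" % _exception_message(e))
        return True  # fall back to the non-Arrow path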

# Fallback to convert to Pandas DataFrame without arrow if raise some exception
Member
@kiszk Feb 11, 2018

Does this PR fall back to the original path if any exception occurs, e.g. when an ImportError happens, whereas the current code throws an exception with a message?
Would it be good to note this change in the description, too?

Member Author

Yup. It does fall back for an unsupported schema, a PyArrow version mismatch, and PyArrow missing. Will add a note in the PR description.
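
For reference, a rough usage sketch of how the fallback surfaces to a user (based on the new test added to tests.py below; it assumes an active SparkSession bound to the name spark, which is not part of this diff):

import warnings

from pyspark.sql.types import StructType, StructField, MapType, StringType, IntegerType

# Assumes an existing SparkSession named `spark`; enable the Arrow code path.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# MapType is not convertible to Arrow, so the pre-check fails and toPandas() falls back.
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = spark.createDataFrame([({u'a': 1},)], schema=schema)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pdf = df.toPandas()  # completes via the non-Arrow path instead of raising

assert any("Arrow will not be used in toPandas" in str(w.message) for w in caught)
print(pdf)  # a one-row pandas DataFrame whose 'map' column holds {'a': 1}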

should_fall_back = True
Member

nit: should_fall_back -> should_fallback other places below too

Member Author

Yup.

warnings.warn(
"Arrow will not be used in toPandas: %s" % _exception_message(e))

if not should_fall_back:
import pyarrow
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps

tables = self._collectAsArrow()
Member

shouldn't this be in the try block?

Member Author

Please see #20567 (comment). @ueshin raised a similar concern.

Member

I see, we don't want to collect twice, so you manually run a schema conversion up front and fall back in that case. I think there still might be some cases where the Arrow path could fail, such as incompatible Arrow versions (e.g. using a possible future version of pyarrow with Java still at 0.8), but this should cover the most common cases, so this seems fine to me.
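
To spell out the trade-off described here, a small sketch of the alternative being weighed (mine, not this commit): wrapping the whole Arrow path in try/except means a failure after the Arrow collection forces a second collect() on the fallback path, which the up-front schema check avoids.

import warnings

import pandas as pd

from pyspark.util import _exception_message


def to_pandas_try_around_collect(df):
    """Sketch of the alternative: if the Arrow collection succeeds but a later step
    fails, the except branch has to pull the data to the driver a second time."""
    try:
        import pyarrow
        tables = df._collectAsArrow()  # first full collection
        return pyarrow.concat_tables(tables).to_pandas()
    except Exception as e:
        warnings.warn("Arrow will not be used in toPandas: %s" % _exception_message(e))
        # possible second full collection of the same data
        return pd.DataFrame.from_records(df.collect(), columns=df.columns)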

if tables:
table = pyarrow.concat_tables(tables)
@@ -1955,38 +1967,34 @@ def toPandas(self):
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
msg = "note: pyarrow must be installed and available on calling Python process " \
"if using spark.sql.execution.arrow.enabled=true"
raise ImportError("%s\n%s" % (_exception_message(e), msg))
else:
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

dtype = {}
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
Member Author

The actual diff here is just the else: line; it was removed, and the rest is the resulting indentation change.


dtype = {}
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)

if timezone is None:
return pdf
else:
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)

if timezone is None:
return pdf
else:
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf
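
As an aside on the SPARK-21766 comment above, a standalone sketch of the pandas behaviour it works around (mine, for illustration): a nullable integer column that was collected with nulls comes back as float64 with NaN, and casting it to the schema-corrected integer dtype raises, so the inferred float dtype is kept.

import numpy as np
import pandas as pd

# A nullable IntegerType column containing a null is inferred by pandas as float64 with NaN.
s = pd.Series([1, None])
print(s.dtype)  # float64

try:
    s.astype(np.int16, copy=False)  # forcing the schema-corrected integer dtype back on it
except ValueError as e:
    # pandas refuses to cast NaN to an integer dtype (exact message varies by version)
    print("conversion refused: %s" % e)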

def _collectAsArrow(self):
"""
4 changes: 3 additions & 1 deletion python/pyspark/sql/session.py
@@ -37,6 +37,7 @@
_make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
_parse_datatype_string
from pyspark.sql.utils import install_exception_handler
from pyspark.util import _exception_message

__all__ = ["SparkSession"]

@@ -666,8 +667,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
# Fallback to create DataFrame without arrow if raise some exception
warnings.warn(
"Arrow will not be used in createDataFrame: %s" % _exception_message(e))
data = self._convert_from_pandas(data, schema, timezone)

if isinstance(schema, StructType):
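
A short sketch of why _exception_message is used here instead of str(e) (my reading, not stated in the diff: the helper retrieves the message safely on Python 2 as well, where str() on an exception carrying a non-ASCII unicode message raises UnicodeEncodeError). The error text below is made up for illustration.

# -*- coding: utf-8 -*-
from pyspark.util import _exception_message

try:
    raise ValueError(u"unsupported type: d\u00e9cimal")  # illustrative message only
except ValueError as e:
    # On Python 2, str(e) would raise UnicodeEncodeError for the non-ASCII message;
    # _exception_message returns it as unicode/str on both Python 2 and 3.
    print("Arrow will not be used in createDataFrame: %s" % _exception_message(e))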
22 changes: 16 additions & 6 deletions python/pyspark/sql/tests.py
@@ -32,6 +32,7 @@
import datetime
import array
import ctypes
import warnings
import py4j

try:
@@ -48,12 +49,12 @@
else:
import unittest

from pyspark.util import _exception_message
Member

nit: add an empty line between this import and _pandas_requirement_message line.

_pandas_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Pandas version requirement is not satisfied, skip related tests.
_pandas_requirement_message = _exception_message(e)

@@ -62,7 +63,6 @@
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)

@@ -3437,12 +3437,22 @@ def create_pandas_data_frame(self):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)

def test_unsupported_datatype(self):
def test_toPandas_fallback(self):
import pandas as pd

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.toPandas()
with warnings.catch_warnings(record=True) as warns:
pdf = df.toPandas()
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Arrow will not be used in toPandas" in _exception_message(user_warns[-1]))

self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +