
Commit a511544

ueshin authored and HyukjinKwon committed
[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2.
## What changes were proposed in this pull request?

In Python 2, when a `pandas_udf` returns string values created in the udf with `".."`, the execution fails. E.g.,

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

df = spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string")
df.select(str_f(col('id'))).show()
```

raises the following exception:

```
...
java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType
	at scala.Predef$.assert(Predef.scala:170)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:93)
...
```

It seems that pyarrow ignores the `type` parameter of `pa.Array.from_pandas()` and treats the values as binary type when the requested type is string but the values are `str` instead of `unicode` in Python 2. This PR adds a workaround for that case.

## How was this patch tested?

Added a test and existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #20507 from ueshin/issues/SPARK-23334.

(cherry picked from commit 63c5bf1)
Signed-off-by: hyukjinkwon <gurwls223@gmail.com>
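For illustration, a minimal standalone sketch of the workaround the patch applies inside `create_array`; the helper name `decode_str_series` is hypothetical and the decode branch only matters under Python 2 with pyarrow installed:

```python
import sys
import pandas as pd
import pyarrow as pa

def decode_str_series(s):
    # Hypothetical helper mirroring the patch: under Python 2, plain str
    # (bytes) values make pyarrow fall back to binary type even when
    # type=pa.string() is requested, so decode them to unicode first.
    if sys.version < '3':
        return s.apply(lambda v: v.decode("utf-8") if isinstance(v, str) else v)
    return s

s = pd.Series(["%s" % i for i in range(3)])
arr = pa.Array.from_pandas(decode_str_series(s), type=pa.string())
# arr.type is now string rather than binary, so Spark's schema
# assertion in ArrowEvalPythonExec passes.
```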
1 parent 4493303 · commit a511544

2 files changed: +13 −0


python/pyspark/serializers.py

4 additions & 0 deletions

```diff
@@ -230,6 +230,10 @@ def create_array(s, t):
                 s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
                 # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
                 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+            elif t is not None and pa.types.is_string(t) and sys.version < '3':
+                # TODO: need decode before converting to Arrow in Python 2
+                return pa.Array.from_pandas(s.apply(
+                    lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
             return pa.Array.from_pandas(s, mask=mask, type=t)
 
         arrs = [create_array(s, t) for s, t in series]
```
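As context for the unchanged `mask=mask` argument above: the mask marks null positions so they survive the Arrow conversion. A quick sketch of that behavior (not part of the patch; assumes pandas and pyarrow are installed):

```python
import pandas as pd
import pyarrow as pa

s = pd.Series(["a", None, "c"])
# mask=True positions become Arrow nulls in the resulting array.
arr = pa.Array.from_pandas(s, mask=s.isnull(), type=pa.string())
print(arr.null_count)  # 1
```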

python/pyspark/sql/tests.py

9 additions & 0 deletions

```diff
@@ -3896,6 +3896,15 @@ def test_vectorized_udf_null_string(self):
         res = df.select(str_f(col('str')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_string_in_udf(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+        str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
+        actual = df.select(str_f(col('id')))
+        expected = df.select(col('id').cast('string'))
+        self.assertEquals(expected.collect(), actual.collect())
+
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10).select(
```
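With the patch applied, the round trip the new test exercises also works interactively; a minimal sketch, assuming an active SparkSession named `spark`:

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

df = spark.range(10)
# Each batch x is a pandas Series of int64; returning str values now
# survives the Arrow conversion with StringType intact.
str_f = pandas_udf(lambda x: pd.Series([str(i) for i in x]), "string")
# Values match casting the column directly, as the test asserts.
assert df.select(str_f(col('id'))).collect() == \
    df.select(col('id').cast('string')).collect()
```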
