Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix pandas_udf with return type StringType() to handle str type properly.
  • Loading branch information
ueshin committed Feb 5, 2018
commit 47b88734b91a7f9a4335bc3c667640eb4600b8e1
3 changes: 3 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ def create_array(s, t):
s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
return pa.Array.from_pandas(s.str.decode('utf-8'), mask=mask, type=t)
Copy link
Member

@HyukjinKwon HyukjinKwon Feb 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ueshin, actually, how about s.apply(lambda v: v.decode("utf-8") if isinstance(v, str) else v) to allow non-ascii encodable unicodes too like u"아"? I was worried of performance but I ran a simple perf test vs s.str.decode('utf-8') for sure. Seems actually fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I'll take it. Thanks!

return pa.Array.from_pandas(s, mask=mask, type=t)

arrs = [create_array(s, t) for s, t in series]
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3920,6 +3920,14 @@ def test_vectorized_udf_null_string(self):
res = df.select(str_f(col('str')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_string_in_udf(self):
    """Verify a pandas_udf returning StringType() round-trips built-in str values.

    Regression test: a UDF that builds its result via str() must produce the
    same rows as casting the column to 'string' on the JVM side.
    """
    from pyspark.sql.functions import pandas_udf, col
    import pandas as pd
    df = self.spark.range(10)
    # map(str, x) covers the plain-str code path directly (reviewer-suggested
    # form, simpler than a "%s" % i comprehension).
    str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
    actual = df.select(str_f(col('id')))
    expected = df.select(col('id').cast('string'))
    self.assertEquals(expected.collect(), actual.collect())

def test_vectorized_udf_datatype_string(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
Expand Down