Skip to content

Commit 4019ba7

Browse files
rxinmarkhamstra
authored andcommitted
Merge pull request alteryx#218 from JoshRosen/spark-970-pyspark-unicode-error
Fix UnicodeEncodeError in PySpark saveAsTextFile() (SPARK-970) This fixes [SPARK-970](https://spark-project.atlassian.net/browse/SPARK-970), an issue where PySpark's saveAsTextFile() could throw UnicodeEncodeError when called on an RDD of Unicode strings. Please merge this into master and branch-0.8. (cherry picked from commit 8a3475a) Signed-off-by: Reynold Xin <[email protected]>
1 parent 4701f48 commit 4019ba7

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

python/pyspark/rdd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,10 @@ def saveAsTextFile(self, path):
598598
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
599599
"""
600600
def func(split, iterator):
601-
return (str(x).encode("utf-8") for x in iterator)
601+
for x in iterator:
602+
if not isinstance(x, basestring):
603+
x = unicode(x)
604+
yield x.encode("utf-8")
602605
keyed = PipelinedRDD(self, func)
603606
keyed._bypass_serializer = True
604607
keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)

python/pyspark/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
Unit tests for PySpark; additional tests are implemented as doctests in
2020
individual modules.
2121
"""
22+
from fileinput import input
23+
from glob import glob
2224
import os
2325
import shutil
2426
import sys
@@ -137,6 +139,19 @@ def func():
137139
self.assertEqual("Hello World from inside a package!", UserClass().hello())
138140

139141

142+
class TestRDDFunctions(PySparkTestCase):
143+
144+
def test_save_as_textfile_with_unicode(self):
145+
# Regression test for SPARK-970
146+
x = u"\u00A1Hola, mundo!"
147+
data = self.sc.parallelize([x])
148+
tempFile = NamedTemporaryFile(delete=True)
149+
tempFile.close()
150+
data.saveAsTextFile(tempFile.name)
151+
raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
152+
self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
153+
154+
140155
class TestIO(PySparkTestCase):
141156

142157
def test_stdout_redirection(self):

0 commit comments

Comments
 (0)