diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 996b7dd59ce9..83afafdd8b13 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -18,6 +18,9 @@ import py4j import sys +if sys.version_info.major >= 3: + unicode = str + class CapturedException(Exception): def __init__(self, desc, stackTrace, cause=None): diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 18fde17f4a06..ccbe21f3a6f3 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -150,6 +151,20 @@ def test_with_different_versions_of_python(self): finally: self.sc.pythonVer = version + def test_python_exception_non_hanging(self): + # SPARK-21045: exceptions with no ascii encoding shall not hanging PySpark. + try: + def f(): + raise Exception("exception with 中 and \xd6\xd0") + + self.sc.parallelize([1]).map(lambda x: f()).count() + except Py4JJavaError as e: + if sys.version_info.major < 3: + # we have to use unicode here to avoid UnicodeDecodeError + self.assertRegexpMatches(unicode(e).encode("utf-8"), "exception with 中") + else: + self.assertRegexpMatches(str(e), "exception with 中") + class WorkerReuseTest(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 086202de2c68..698193d6bdd8 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -598,8 +598,15 @@ def process(): process() except Exception: try: + exc_info = traceback.format_exc() + if isinstance(exc_info, bytes): + # exc_info may contains other encoding bytes, replace the invalid bytes and convert + # it back to utf-8 again + exc_info = exc_info.decode("utf-8", "replace").encode("utf-8") + else: + exc_info = exc_info.encode("utf-8") write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc().encode("utf-8"), outfile) + write_with_length(exc_info, outfile) except IOError: # JVM close the socket pass