Skip to content

Commit 90559c0

Browse files
committed
[SPARK-21045][PYSPARK] Defensive check for exception info thrown by user.
1 parent d74fc6b commit 90559c0

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

python/pyspark/testing/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import struct
2020
import sys
21+
import threading
2122
import unittest
2223

2324
from pyspark import SparkContext, SparkConf
@@ -127,3 +128,18 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
127128
raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
128129
else:
129130
return jars[0]
131+
132+
133+
class ExecThread(threading.Thread):
    """Thread wrapper that records any exception raised by its target.

    After the thread has finished, ``exception`` holds the exception
    instance raised by ``target`` (or ``None`` if it completed normally),
    so the spawning thread can inspect the failure after ``join()``.
    """

    def __init__(self, target):
        super(ExecThread, self).__init__()
        self.target = target
        self.exception = None

    def run(self):
        try:
            self.target()
        except Exception as exc:  # record any failure for the caller
            self.exception = exc

python/pyspark/tests/test_worker.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# -*- encoding: utf-8 -*-
12
#
23
# Licensed to the Apache Software Foundation (ASF) under one or more
34
# contributor license agreements. See the NOTICE file distributed with
@@ -28,7 +29,7 @@
2829

2930
from py4j.protocol import Py4JJavaError
3031

31-
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
32+
from pyspark.testing.utils import ExecThread, ReusedPySparkTestCase, PySparkTestCase, QuietTest
3233

3334
if sys.version_info[0] >= 3:
3435
xrange = range
@@ -150,6 +151,28 @@ def test_with_different_versions_of_python(self):
150151
finally:
151152
self.sc.pythonVer = version
152153

154+
def test_python_exception_non_hanging(self):
    """
    SPARK-21045: an exception whose message is not ASCII-encodable must
    not hang PySpark.

    The job runs on a daemon thread with a 10s join timeout so that a
    regression (the worker hanging while serializing the traceback)
    fails the test instead of blocking the suite forever.
    """
    def f():
        # Message mixes a non-ASCII code point with raw non-UTF-8 bytes.
        raise Exception("exception with 中 and \xd6\xd0")

    def run():
        self.sc.parallelize([1]).map(lambda x: f()).count()

    t = ExecThread(target=run)
    t.daemon = True
    t.start()
    t.join(10)
    # isAlive() was removed in Python 3.9; is_alive() exists on both
    # Python 2.6+ and Python 3, so it is safe for either branch below.
    self.assertFalse(t.is_alive(), "Spark should not be blocked")
    self.assertIsInstance(t.exception, Py4JJavaError)
    if sys.version_info.major < 3:
        # we have to use unicode here to avoid UnicodeDecodeError
        self.assertRegexpMatches(unicode(t.exception).encode("utf-8"), "exception with 中")
    else:
        # assertRegexpMatches is a deprecated (removed in 3.12) alias of
        # assertRegex on Python 3.
        self.assertRegex(str(t.exception), "exception with 中")
175+
153176

154177
class WorkerReuseTest(PySparkTestCase):
155178

python/pyspark/worker.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from pyspark.util import _get_argspec, fail_on_stopiteration
4545
from pyspark import shuffle
4646

47-
if sys.version >= '3':
47+
if sys.version_info.major >= 3:
4848
basestring = str
4949
else:
5050
from itertools import imap as map # use iterator map by default
@@ -598,8 +598,18 @@ def process():
598598
process()
599599
except Exception:
600600
try:
601+
exc_info = traceback.format_exc()
602+
if sys.version_info.major < 3:
603+
if isinstance(exc_info, unicode):
604+
exc_info = exc_info.encode("utf-8")
605+
else:
606+
# exc_info may contain bytes in other encodings; replace any invalid byte and
607+
# convert it back to utf-8 again
608+
exc_info = exc_info.decode("utf-8", "replace").encode("utf-8")
609+
else:
610+
exc_info = exc_info.encode("utf-8")
601611
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
602-
write_with_length(traceback.format_exc().encode("utf-8"), outfile)
612+
write_with_length(exc_info, outfile)
603613
except IOError:
604614
# JVM close the socket
605615
pass

0 commit comments

Comments
 (0)