diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 70db4bbe4cbc..c3bcaf9f3de2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -687,12 +687,14 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}):
+    def pipe(self, command, env={}, checkCode=False):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
+
+        :param checkCode: whether or not to check the return value of the shell command.
         """
         def func(iterator):
             pipe = Popen(
@@ -704,7 +706,17 @@ def pipe_objs(out):
                     out.write(s.encode('utf-8'))
                 out.close()
             Thread(target=pipe_objs, args=[pipe.stdin]).start()
-            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
+
+            def check_return_code():
+                pipe.wait()
+                if checkCode and pipe.returncode:
+                    raise Exception("Pipe function `%s' exited "
+                                    "with error code %d" % (command, pipe.returncode))
+                else:
+                    for i in range(0):
+                        yield i
+            return (x.rstrip(b'\n').decode('utf-8') for x in
+                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
         return self.mapPartitions(func)
 
     def foreach(self, f):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index d8e319994cc9..46368c20d44b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -874,6 +874,18 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
         for size in sizes:
             self.assertGreater(size, 0)
 
+    def test_pipe_functions(self):
+        data = ['1', '2', '3']
+        rdd = self.sc.parallelize(data)
+        with QuietTest(self.sc):
+            self.assertEqual([], rdd.pipe('cc').collect())
+            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
+        result = rdd.pipe('cat').collect()
+        result.sort()
+        [self.assertEqual(x, y) for x, y in zip(data, result)]
+        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
+        self.assertEqual([], rdd.pipe('grep 4').collect())
+
 
 class ProfilerTests(PySparkTestCase):
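
A reviewer-facing note on the interesting piece of this patch: how `check_return_code` is spliced into the output stream. Because it contains a `yield`, it is a generator function (the `for i in range(0): yield i` body exists only to make it one while producing nothing), so calling it returns a lazy generator, and `chain()` only advances into it after `pipe.stdout` is exhausted. As a result, `pipe.wait()` and the exit-code check run exactly once, after the consumer has read the last output line, without buffering any output. Below is a minimal standalone sketch of that trick in plain Python, not part of this patch; `run_pipe`, `cmd`, and `check_code` are illustrative names, and stdin is fed from a thread as in `RDD.pipe` to avoid deadlock on large inputs.

```python
from itertools import chain
from subprocess import Popen, PIPE
from threading import Thread


def run_pipe(cmd, lines, check_code=False):
    """Pipe `lines` through shell command `cmd`, mirroring the patched RDD.pipe."""
    proc = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE)

    def feed_stdin(out):
        for s in lines:
            out.write((s + '\n').encode('utf-8'))
        out.close()
    Thread(target=feed_stdin, args=[proc.stdin]).start()

    def check_return_code():
        # Only reached after chain() has exhausted proc.stdout, i.e. after
        # the consumer has read every line the command produced.
        proc.wait()
        if check_code and proc.returncode:
            raise Exception("Pipe function `%s' exited "
                            "with error code %d" % (cmd, proc.returncode))
        else:
            # Empty generator: the yield makes this a generator function,
            # but range(0) guarantees it yields no elements.
            for i in range(0):
                yield i

    return (x.rstrip(b'\n').decode('utf-8') for x in
            chain(iter(proc.stdout.readline, b''), check_return_code()))


print(list(run_pipe('cat', ['1', '2', '3'])))       # ['1', '2', '3']
print(list(run_pipe('grep 4', ['1', '2', '3'])))    # [] -- nonzero exit ignored
list(run_pipe('grep 4', ['1', '2', '3'], check_code=True))  # raises Exception
```

From the PySpark side (assuming a live `SparkContext` named `sc`), the new flag behaves as the added test asserts: `rdd.pipe('grep 4').collect()` still returns `[]`, while `rdd.pipe('grep 4', checkCode=True).collect()` fails the task, which surfaces on the driver as a `Py4JJavaError`.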