Skip to content

Commit 71df047

Browse files
committed
Add emptyRDD to pyspark and fix the issue when calling sum on an empty RDD
1 parent 4c5889e commit 71df047

File tree

4 files changed: 20 additions and 1 deletion.

4 files changed

+20
-1
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -425,6 +425,11 @@ private[spark] object PythonRDD extends Logging {
425425
iter.foreach(write)
426426
}
427427

428+
/** Create an RDD that has no partitions or elements. */
429+
def emptyRDD[T](sc: JavaSparkContext): JavaRDD[Array[Byte]] = {
430+
sc.emptyRDD[Array[Byte]]
431+
}
432+
428433
/**
429434
* Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]],
430435
* key and value class.

python/pyspark/context.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -324,6 +324,12 @@ def stop(self):
324324
with SparkContext._lock:
325325
SparkContext._active_spark_context = None
326326

327+
def emptyRDD(self):
328+
"""
329+
Create an RDD that has no partitions or elements.
330+
"""
331+
return RDD(self._jsc.emptyRDD(), self, NoOpSerializer())
332+
327333
def range(self, start, end=None, step=1, numSlices=None):
328334
"""
329335
Create a new RDD of int containing elements from `start` to `end`

python/pyspark/rdd.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -960,7 +960,7 @@ def sum(self):
960960
>>> sc.parallelize([1.0, 2.0, 3.0]).sum()
961961
6.0
962962
"""
963-
return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
963+
return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
964964

965965
def count(self):
966966
"""

python/pyspark/tests.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -458,6 +458,14 @@ def test_id(self):
458458
self.assertEqual(id + 1, id2)
459459
self.assertEqual(id2, rdd2.id())
460460

461+
def test_empty_rdd(self):
462+
rdd = self.sc.emptyRDD()
463+
self.assertTrue(rdd.isEmpty())
464+
465+
def test_sum(self):
466+
self.assertEqual(0, self.sc.emptyRDD().sum())
467+
self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
468+
461469
def test_save_as_textfile_with_unicode(self):
462470
# Regression test for SPARK-970
463471
x = u"\u00A1Hola, mundo!"

0 commit comments

Comments
 (0)