improved doc, error message and code style
e-dorigatti committed May 24, 2018
commit d739eea9e8ed07dad9dd9b1a795ff21e8f915694
35 changes: 17 additions & 18 deletions python/pyspark/rdd.py
@@ -51,7 +51,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import fail_on_StopIteration
from pyspark.util import fail_on_stopiteration


__all__ = ["RDD"]
@@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
return map(fail_on_StopIteration(f), iterator)
return map(fail_on_stopiteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
return chain.from_iterable(map(fail_on_StopIteration(f), iterator))
return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -411,7 +411,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
return filter(fail_on_StopIteration(f), iterator)
return filter(fail_on_stopiteration(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -792,11 +792,11 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
safe_f = fail_on_StopIteration(f)
f = fail_on_stopiteration(f)

def processPartition(iterator):
for x in iterator:
safe_f(x)
f(x)
return iter([])
self.mapPartitions(processPartition).count() # Force evaluation

@@ -843,15 +843,15 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
safe_f = fail_on_StopIteration(f)
f = fail_on_stopiteration(f)

def func(iterator):
iterator = iter(iterator)
try:
initial = next(iterator)
except StopIteration:
return
yield reduce(safe_f, iterator, initial)
yield reduce(f, iterator, initial)

vals = self.mapPartitions(func).collect()
if vals:
@@ -916,12 +916,12 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
safe_op = fail_on_StopIteration(op)
op = fail_on_stopiteration(op)

def func(iterator):
acc = zeroValue
for obj in iterator:
acc = safe_op(acc, obj)
acc = op(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
@@ -950,19 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
safe_seqOp = fail_on_StopIteration(seqOp)
safe_combOp = fail_on_StopIteration(combOp)
seqOp = fail_on_stopiteration(seqOp)
combOp = fail_on_stopiteration(combOp)

def func(iterator):
acc = zeroValue
for obj in iterator:
acc = safe_seqOp(acc, obj)
acc = seqOp(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
# to the final reduce call
vals = self.mapPartitions(func).collect()
return reduce(safe_combOp, vals, zeroValue)
return reduce(combOp, vals, zeroValue)

def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
"""
@@ -1646,17 +1646,17 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
safe_func = fail_on_StopIteration(func)
func = fail_on_stopiteration(func)

def reducePartition(iterator):
m = {}
for k, v in iterator:
m[k] = safe_func(m[k], v) if k in m else v
m[k] = func(m[k], v) if k in m else v
yield m

def mergeMaps(m1, m2):
for k, v in m2.items():
m1[k] = safe_func(m1[k], v) if k in m1 else v
m1[k] = func(m1[k], v) if k in m1 else v
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)

@@ -1858,7 +1858,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
>>> sorted(x.combineByKey(to_list, append, extend).collect())
[('a', [1, 2]), ('b', [1])]
"""

if numPartitions is None:
numPartitions = self._defaultReducePartitions()

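Every hunk in rdd.py follows the same pattern: the user-supplied function is wrapped with `fail_on_stopiteration` before it is driven by `map()`, `filter()`, or a `for` loop over the partition iterator. Below is a hedged sketch of how the underlying problem shows up at the PySpark API level; it assumes a running `SparkContext` named `sc` (as in the doctests above), and `buggy` is an invented user function, not code from this PR.

```python
# Hypothetical reproduction (not from the PR). The user function calls next()
# on an exhausted iterator, which raises StopIteration.
exhausted = iter([])

def buggy(x):
    return next(exhausted)   # StopIteration on every call

rdd = sc.parallelize([1, 2, 3, 4], 2)

# Without the wrapping, the StopIteration escaping from `buggy` looked like
# "end of partition" to the loop driving it, so collect() could silently
# return a truncated result. With fail_on_stopiteration applied, the task
# fails with:
#   RuntimeError: Caught StopIteration thrown from user's code; failing the task
rdd.map(buggy).collect()
```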
8 changes: 4 additions & 4 deletions python/pyspark/shuffle.py
@@ -28,7 +28,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
from pyspark.util import fail_on_StopIteration
from pyspark.util import fail_on_stopiteration


try:
@@ -95,9 +95,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
self.createCombiner = fail_on_StopIteration(createCombiner)
self.mergeValue = fail_on_StopIteration(mergeValue)
self.mergeCombiners = fail_on_StopIteration(mergeCombiners)
self.createCombiner = fail_on_stopiteration(createCombiner)
self.mergeValue = fail_on_stopiteration(mergeValue)
self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):
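In shuffle.py the wrapping is applied once, in `Aggregator.__init__`, so every code path that later calls `createCombiner`, `mergeValue`, or `mergeCombiners` is covered. As a reminder of what those callbacks look like, here is a small sketch using the names from the `combineByKey` doctest in rdd.py above; the comments about pre-change behaviour are an interpretation of the fix, not text from the PR.

```python
# Combiner callbacks in the style of the combineByKey doctest above.
def to_list(a):
    # createCombiner: turn the first value seen for a key into a combiner
    return [a]

def append(a, b):
    # mergeValue: fold one more value into an existing combiner
    a.append(b)
    return a

def extend(a, b):
    # mergeCombiners: merge two per-partition combiners
    a.extend(b)
    return a

# Aggregator now stores these pre-wrapped, e.g.
#   self.mergeValue = fail_on_stopiteration(mergeValue)
# so a StopIteration raised from user code during an external merge surfaces
# as a RuntimeError instead of quietly ending the merge loop.
```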
9 changes: 5 additions & 4 deletions python/pyspark/util.py
@@ -89,15 +89,16 @@ def majorMinorVersion(sparkVersion):
" version numbers.")


def fail_on_StopIteration(f):
""" wraps f to make it safe (= does not lead to data loss) to use inside a for loop
make StopIteration's raised inside f explicit
def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
prevents silent loss of data when 'f' is used in a for loop
"""
Review comment (Member):

    """
    Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
    prevents silent loss of data when 'f' is used in a for loop
    """

per PEP 8

def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
raise RuntimeError('StopIteration in client code', exc)
raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc)

return wrapper

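Outside of Spark, the failure mode this wrapper guards against is easy to reproduce: to whatever loop is driving it, a `StopIteration` escaping from a callback is indistinguishable from "iterator exhausted". Below is a self-contained Python 3 sketch; the wrapper is restated so the snippet runs on its own, and `buggy` and the sample data are invented for illustration.

```python
def fail_on_stopiteration(f):
    # Same idea as the wrapper above: surface a StopIteration raised by user
    # code as a RuntimeError instead of letting it end the driving loop.
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper


def buggy(x):
    if x == 3:
        raise StopIteration()   # e.g. a stray next() on an exhausted iterator
    return x * 10


data = [1, 2, 3, 4, 5]

# Unwrapped: list() treats the StopIteration coming out of map() as "no more
# elements" and quietly drops 3, 4 and 5.
print(list(map(buggy, data)))            # [10, 20]  -- silent data loss

# Wrapped: the same input now fails loudly instead of losing data.
try:
    list(map(fail_on_stopiteration(buggy), data))
except RuntimeError as err:
    print("task failed:", err)
```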