Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
moved safe_iter to util module and more descriptive name
  • Loading branch information
e-dorigatti committed May 22, 2018
commit fddd031bbe4dda108739169f0a27eacae8f33099
22 changes: 11 additions & 11 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import fail_on_StopIteration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would name it fail_on_stop_iteration or fail_on_stopiteration per PEP 8.



__all__ = ["RDD"]
Expand Down Expand Up @@ -173,7 +174,6 @@ def ignore_unicode_prefix(f):
return f



class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
Expand Down Expand Up @@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
return map(safe_iter(f), iterator)
return map(fail_on_StopIteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
Expand All @@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
return chain.from_iterable(map(safe_iter(f), iterator))
return chain.from_iterable(map(fail_on_StopIteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
Expand Down Expand Up @@ -411,7 +411,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
return filter(safe_iter(f), iterator)
return filter(fail_on_StopIteration(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
Expand Down Expand Up @@ -792,7 +792,7 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
safe_f = safe_iter(f)
safe_f = fail_on_StopIteration(f)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

safe prefix doesn't imply why it's safe though .. I would just name it like fail_on_stopiteration_f, or feel free to use another name if you have a good one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with safe as is too if you feel strongly.


def processPartition(iterator):
for x in iterator:
Expand Down Expand Up @@ -843,7 +843,7 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
safe_f = safe_iter(f)
safe_f = fail_on_StopIteration(f)

def func(iterator):
iterator = iter(iterator)
Expand Down Expand Up @@ -916,7 +916,7 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
safe_op = safe_iter(op)
safe_op = fail_on_StopIteration(op)

def func(iterator):
acc = zeroValue
Expand Down Expand Up @@ -950,8 +950,8 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
safe_seqOp = safe_iter(seqOp)
safe_combOp = safe_iter(combOp)
safe_seqOp = fail_on_StopIteration(seqOp)
safe_combOp = fail_on_StopIteration(combOp)

def func(iterator):
acc = zeroValue
Expand Down Expand Up @@ -1646,7 +1646,7 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
safe_func = safe_iter(func)
safe_func = fail_on_StopIteration(func)

def reducePartition(iterator):
m = {}
Expand Down
20 changes: 4 additions & 16 deletions python/pyspark/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
from pyspark.util import fail_on_StopIteration


try:
Expand Down Expand Up @@ -67,19 +68,6 @@ def get_used_memory():
return 0


def safe_iter(f):
    """Wrap ``f`` so it is safe to call from inside a for loop.

    A ``StopIteration`` escaping user code would silently terminate the
    surrounding loop (data loss); surface it as a ``RuntimeError`` instead.
    """
    def wrapper(*args, **kwargs):
        try:
            result = f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError('StopIteration in client code', exc)
        return result

    return wrapper


def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
Expand Down Expand Up @@ -107,9 +95,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
self.createCombiner = safe_iter(createCombiner)
self.mergeValue = safe_iter(mergeValue)
self.mergeCombiners = safe_iter(mergeCombiners)
self.createCombiner = fail_on_StopIteration(createCombiner)
self.mergeValue = fail_on_StopIteration(mergeValue)
self.mergeCombiners = fail_on_StopIteration(mergeCombiners)


class SimpleAggregator(Aggregator):
Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ def majorMinorVersion(sparkVersion):
" version numbers.")


def fail_on_StopIteration(f):
""" wraps f to make it safe (= does not lead to data loss) to use inside a for loop
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a big deal at all but wraps -> Wraps while we are here.

Copy link
Member

@HyukjinKwon HyukjinKwon May 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like Wraps the input function to fail on StopIteration by RuntimeError to prevent data loss silently ... blabla?

make StopIteration's raised inside f explicit
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    """
    Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
    prevents silent loss of data when 'f' is used in a for loop
    """

per PEP 8

def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
raise RuntimeError('StopIteration in client code', exc)

return wrapper


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
Expand Down