-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-23754][Python] Re-raising StopIteration in client code #21383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ec7854a
fddd031
ee54924
d739eea
f0f80ed
d59f0d5
b0af18e
167a75b
90b064d
75316af
026ecdd
f7b53c2
8fac2a8
5b5570b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,8 +49,9 @@ | |
| from pyspark.storagelevel import StorageLevel | ||
| from pyspark.resultiterable import ResultIterable | ||
| from pyspark.shuffle import Aggregator, ExternalMerger, \ | ||
| get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter | ||
| get_used_memory, ExternalSorter, ExternalGroupBy | ||
| from pyspark.traceback_utils import SCCallSiteSync | ||
| from pyspark.util import fail_on_StopIteration | ||
|
|
||
|
|
||
| __all__ = ["RDD"] | ||
|
|
@@ -173,7 +174,6 @@ def ignore_unicode_prefix(f): | |
| return f | ||
|
|
||
|
|
||
|
|
||
| class Partitioner(object): | ||
| def __init__(self, numPartitions, partitionFunc): | ||
| self.numPartitions = numPartitions | ||
|
|
@@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False): | |
| [('a', 1), ('b', 1), ('c', 1)] | ||
| """ | ||
| def func(_, iterator): | ||
| return map(safe_iter(f), iterator) | ||
| return map(fail_on_StopIteration(f), iterator) | ||
| return self.mapPartitionsWithIndex(func, preservesPartitioning) | ||
|
|
||
| def flatMap(self, f, preservesPartitioning=False): | ||
|
|
@@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): | |
| [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] | ||
| """ | ||
| def func(s, iterator): | ||
| return chain.from_iterable(map(safe_iter(f), iterator)) | ||
| return chain.from_iterable(map(fail_on_StopIteration(f), iterator)) | ||
| return self.mapPartitionsWithIndex(func, preservesPartitioning) | ||
|
|
||
| def mapPartitions(self, f, preservesPartitioning=False): | ||
|
|
@@ -411,7 +411,7 @@ def filter(self, f): | |
| [2, 4] | ||
| """ | ||
| def func(iterator): | ||
| return filter(safe_iter(f), iterator) | ||
| return filter(fail_on_StopIteration(f), iterator) | ||
| return self.mapPartitions(func, True) | ||
|
|
||
| def distinct(self, numPartitions=None): | ||
|
|
@@ -792,7 +792,7 @@ def foreach(self, f): | |
| >>> def f(x): print(x) | ||
| >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) | ||
| """ | ||
| safe_f = safe_iter(f) | ||
| safe_f = fail_on_StopIteration(f) | ||
|
||
|
|
||
| def processPartition(iterator): | ||
| for x in iterator: | ||
|
|
@@ -843,7 +843,7 @@ def reduce(self, f): | |
| ... | ||
| ValueError: Can not reduce() empty RDD | ||
| """ | ||
| safe_f = safe_iter(f) | ||
| safe_f = fail_on_StopIteration(f) | ||
|
|
||
| def func(iterator): | ||
| iterator = iter(iterator) | ||
|
|
@@ -916,7 +916,7 @@ def fold(self, zeroValue, op): | |
| >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) | ||
| 15 | ||
| """ | ||
| safe_op = safe_iter(op) | ||
| safe_op = fail_on_StopIteration(op) | ||
|
|
||
| def func(iterator): | ||
| acc = zeroValue | ||
|
|
@@ -950,8 +950,8 @@ def aggregate(self, zeroValue, seqOp, combOp): | |
| >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) | ||
| (0, 0) | ||
| """ | ||
| safe_seqOp = safe_iter(seqOp) | ||
| safe_combOp = safe_iter(combOp) | ||
| safe_seqOp = fail_on_StopIteration(seqOp) | ||
| safe_combOp = fail_on_StopIteration(combOp) | ||
|
|
||
| def func(iterator): | ||
| acc = zeroValue | ||
|
|
@@ -1646,7 +1646,7 @@ def reduceByKeyLocally(self, func): | |
| >>> sorted(rdd.reduceByKeyLocally(add).items()) | ||
| [('a', 2), ('b', 1)] | ||
| """ | ||
| safe_func = safe_iter(func) | ||
| safe_func = fail_on_StopIteration(func) | ||
|
|
||
| def reducePartition(iterator): | ||
| m = {} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,6 +89,19 @@ def majorMinorVersion(sparkVersion): | |
| " version numbers.") | ||
|
|
||
|
|
||
| def fail_on_StopIteration(f): | ||
| """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop | ||
|
||
| make StopIteration's raised inside f explicit | ||
| """ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. per PEP 8 |
||
| def wrapper(*args, **kwargs): | ||
| try: | ||
| return f(*args, **kwargs) | ||
| except StopIteration as exc: | ||
| raise RuntimeError('StopIteration in client code', exc) | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import doctest | ||
| (failure_count, test_count) = doctest.testmod() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would name it
`fail_on_stop_iteration` or `fail_on_stopiteration` per PEP 8.