[SPARK-2871] [PySpark] add RDD.lookup(key) #2093
Changes from 1 commit
eb1305d
be0e8ba
0f1bce8
c6390ea
2871b80
1789cd4
refactor
@@ -131,76 +131,6 @@ def __exit__(self, type, value, tb): | |
| self._context._jsc.setCallSite(None) | ||
|
|
||
|
|
||
| class MaxHeapQ(object): | ||
|
|
||
| """ | ||
| An implementation of MaxHeap. | ||
|
|
||
| >>> import pyspark.rdd | ||
| >>> heap = pyspark.rdd.MaxHeapQ(5) | ||
| >>> [heap.insert(i) for i in range(10)] | ||
| [None, None, None, None, None, None, None, None, None, None] | ||
| >>> sorted(heap.getElements()) | ||
| [0, 1, 2, 3, 4] | ||
| >>> heap = pyspark.rdd.MaxHeapQ(5) | ||
| >>> [heap.insert(i) for i in range(9, -1, -1)] | ||
| [None, None, None, None, None, None, None, None, None, None] | ||
| >>> sorted(heap.getElements()) | ||
| [0, 1, 2, 3, 4] | ||
| >>> heap = pyspark.rdd.MaxHeapQ(1) | ||
| >>> [heap.insert(i) for i in range(9, -1, -1)] | ||
| [None, None, None, None, None, None, None, None, None, None] | ||
| >>> heap.getElements() | ||
| [0] | ||
| """ | ||
|
|
||
| def __init__(self, maxsize): | ||
| # We start from q[1], so its children are always 2 * k | ||
| self.q = [0] | ||
| self.maxsize = maxsize | ||
|
|
||
| def _swim(self, k): | ||
| while (k > 1) and (self.q[k / 2] < self.q[k]): | ||
| self._swap(k, k / 2) | ||
| k = k / 2 | ||
|
|
||
| def _swap(self, i, j): | ||
| t = self.q[i] | ||
| self.q[i] = self.q[j] | ||
| self.q[j] = t | ||
|
|
||
| def _sink(self, k): | ||
| N = self.size() | ||
| while 2 * k <= N: | ||
| j = 2 * k | ||
| # Here we test if both children are greater than parent | ||
| # if not swap with larger one. | ||
| if j < N and self.q[j] < self.q[j + 1]: | ||
| j = j + 1 | ||
| if(self.q[k] > self.q[j]): | ||
| break | ||
| self._swap(k, j) | ||
| k = j | ||
|
|
||
| def size(self): | ||
| return len(self.q) - 1 | ||
|
|
||
| def insert(self, value): | ||
| if (self.size()) < self.maxsize: | ||
| self.q.append(value) | ||
| self._swim(self.size()) | ||
| else: | ||
| self._replaceRoot(value) | ||
|
|
||
| def getElements(self): | ||
| return self.q[1:] | ||
|
|
||
| def _replaceRoot(self, value): | ||
| if(self.q[1] > value): | ||
| self.q[1] = value | ||
| self._sink(1) | ||
|
|
||
|
|
||
| def _parse_memory(s): | ||
| """ | ||
| Parse a memory string in the format supported by Java (e.g. 1g, 200m) and | ||
|
|
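Note: the MaxHeapQ class removed above was a bounded max-heap used only to keep the smallest `num` elements; the rest of this commit replaces it with the standard heapq module. A minimal local sketch (plain Python, no Spark) showing the equivalent behaviour of the removed doctests:
>>> import heapq
>>> heapq.nsmallest(5, range(10))          # what MaxHeapQ(5) kept after inserting 0..9
[0, 1, 2, 3, 4]
>>> heapq.nsmallest(1, range(9, -1, -1))   # the MaxHeapQ(1) case
[0]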
@@ -232,6 +162,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): | |
| self.ctx = ctx | ||
| self._jrdd_deserializer = jrdd_deserializer | ||
| self._id = jrdd.id() | ||
| self._partitionFunc = None | ||
|
|
||
| def _toPickleSerialization(self): | ||
| if (self._jrdd_deserializer == PickleSerializer() or | ||
|
|
@@ -309,8 +240,6 @@ def getCheckpointFile(self): | |
| checkpointFile = self._jrdd.rdd().getCheckpointFile() | ||
| if checkpointFile.isDefined(): | ||
| return checkpointFile.get() | ||
| else: | ||
| return None | ||
|
|
||
| def map(self, f, preservesPartitioning=False): | ||
| """ | ||
|
|
@@ -350,7 +279,7 @@ def mapPartitions(self, f, preservesPartitioning=False): | |
| """ | ||
| def func(s, iterator): | ||
| return f(iterator) | ||
| return self.mapPartitionsWithIndex(func) | ||
| return self.mapPartitionsWithIndex(func, preservesPartitioning) | ||
|
|
||
| def mapPartitionsWithIndex(self, f, preservesPartitioning=False): | ||
| """ | ||
|
|
@@ -400,7 +329,7 @@ def filter(self, f): | |
| """ | ||
| def func(iterator): | ||
| return ifilter(f, iterator) | ||
| return self.mapPartitions(func) | ||
| return self.mapPartitions(func, True) | ||
|
|
||
| def distinct(self): | ||
| """ | ||
|
|
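Threading preservesPartitioning through mapPartitions() and marking filter() as partition-preserving keeps the parent's partition function on the child RDD, which is what the new lookup() checks for its fast path. A hedged usage sketch, assuming a live SparkContext `sc` as in the surrounding doctests:
>>> pairs = sc.parallelize(zip(range(100), range(100))).sortByKey()
>>> evens = pairs.filter(lambda kv: kv[0] % 2 == 0)   # filter no longer discards partition info
>>> evens.lookup(42)
[42]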
@@ -545,7 +474,7 @@ def intersection(self, other): | |
| """ | ||
| return self.map(lambda v: (v, None)) \ | ||
| .cogroup(other.map(lambda v: (v, None))) \ | ||
| .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ | ||
| .filter(lambda (k, vs): all(vs)) \ | ||
| .keys() | ||
|
|
||
| def _reserialize(self, serializer=None): | ||
|
|
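The new filter in intersection() keeps a key only when both cogrouped value iterables are non-empty, which `all(vs)` expresses directly via truthiness. A plain-Python illustration:
>>> all(([1], [2]))    # present on both sides: keep the key
True
>>> all(([1], []))     # missing on one side: drop it
False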
@@ -595,7 +524,7 @@ def sortPartition(iterator): | |
| if numPartitions == 1: | ||
| if self.getNumPartitions() > 1: | ||
| self = self.coalesce(1) | ||
| return self.mapPartitions(sortPartition) | ||
| return self.mapPartitions(sortPartition, True) | ||
|
|
||
| # first compute the boundary of each part via sampling: we want to partition | ||
| # the key-space into bins such that the bins have roughly the same | ||
|
|
@@ -700,8 +629,8 @@ def foreach(self, f): | |
| def processPartition(iterator): | ||
| for x in iterator: | ||
| f(x) | ||
| yield None | ||
| self.mapPartitions(processPartition).collect() # Force evaluation | ||
| return iter([]) | ||
| self.mapPartitions(processPartition).count() # Force evaluation | ||
|
|
||
| def foreachPartition(self, f): | ||
| """ | ||
|
|
@@ -713,7 +642,10 @@ def foreachPartition(self, f): | |
| ... yield None | ||
| >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) | ||
| """ | ||
| self.mapPartitions(f).collect() # Force evaluation | ||
| def func(it): | ||
| f(it) | ||
| return iter([]) | ||
| self.mapPartitions(func).count() # Force evaluation | ||
|
|
||
| def collect(self): | ||
| """ | ||
|
|
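Returning an empty iterator and forcing evaluation with count() means foreach() and foreachPartition() no longer ship per-element results back to the driver; only the side effects run on the executors. A hedged sketch using an accumulator (assumes a live `sc`):
>>> acc = sc.accumulator(0)
>>> def visit(x): acc.add(1)
>>> sc.parallelize(range(100)).foreach(visit)   # nothing is collected, only the side effect runs
>>> acc.value
100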
@@ -746,18 +678,23 @@ def reduce(self, f): | |
| 15 | ||
| >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) | ||
| 10 | ||
| >>> sc.parallelize([]).reduce(add) | ||
| Traceback (most recent call last): | ||
| ... | ||
| ValueError: Can not reduce() of empty RDD | ||
| """ | ||
| def func(iterator): | ||
| acc = None | ||
| for obj in iterator: | ||
| if acc is None: | ||
| acc = obj | ||
| else: | ||
| acc = f(obj, acc) | ||
| if acc is not None: | ||
| yield acc | ||
| iterator = iter(iterator) | ||
| try: | ||
| initial = next(iterator) | ||
| except StopIteration: | ||
| return | ||
| yield reduce(f, iterator, initial) | ||
|
|
||
| vals = self.mapPartitions(func).collect() | ||
| return reduce(f, vals) | ||
| if vals: | ||
| return reduce(f, vals) | ||
| raise ValueError("Can not reduce() of empty RDD") | ||
|
|
||
| def fold(self, zeroValue, op): | ||
| """ | ||
|
|
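The rewritten reduce() emits one partial result per non-empty partition and merges the partials on the driver; an RDD whose partitions are all empty produces no partials and now raises ValueError. A minimal local sketch of the same two-level pattern (plain Python, no Spark):
>>> from functools import reduce   # a builtin in Python 2, imported here for clarity
>>> from operator import add
>>> def part_reduce(f, it):
...     it = iter(it)
...     try:
...         first = next(it)
...     except StopIteration:
...         return []               # an empty partition contributes nothing
...     return [reduce(f, it, first)]
>>> parts = [[1, 2], [], [3, 4, 5]]
>>> vals = sum((part_reduce(add, p) for p in parts), [])
>>> reduce(add, vals) if vals else 'Can not reduce() of empty RDD'
15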
@@ -919,7 +856,7 @@ def countPartition(iterator): | |
| yield counts | ||
|
|
||
| def mergeMaps(m1, m2): | ||
| for (k, v) in m2.iteritems(): | ||
| for k, v in m2.iteritems(): | ||
| m1[k] += v | ||
| return m1 | ||
| return self.mapPartitions(countPartition).reduce(mergeMaps) | ||
|
|
@@ -935,18 +872,12 @@ def top(self, num): | |
| [6, 5] | ||
| """ | ||
| def topIterator(iterator): | ||
| q = [] | ||
| for k in iterator: | ||
| if len(q) < num: | ||
| heapq.heappush(q, k) | ||
| else: | ||
| heapq.heappushpop(q, k) | ||
| yield q | ||
| return [heapq.nlargest(num, iterator)] | ||
|
|
||
| def merge(a, b): | ||
| return next(topIterator(a + b)) | ||
| return heapq.nlargest(num, a + b) | ||
|
|
||
| return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) | ||
| return self.mapPartitions(topIterator).reduce(merge) | ||
|
|
||
| def takeOrdered(self, num, key=None): | ||
| """ | ||
|
|
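top() now takes the `num` largest elements per partition with heapq.nlargest and merges partials the same way, so the result already comes back in descending order and the final sorted(..., reverse=True) is gone. A local sketch of the merge (Python 2, as in the patch):
>>> import heapq
>>> num = 2
>>> partials = [heapq.nlargest(num, p) for p in [[10, 4, 2], [9, 7], [8]]]
>>> reduce(lambda a, b: heapq.nlargest(num, a + b), partials)
[10, 9]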
@@ -959,24 +890,10 @@ def takeOrdered(self, num, key=None): | |
| [10, 9, 7, 6, 5, 4] | ||
| """ | ||
|
|
||
| def topNKeyedElems(iterator, key_=None): | ||
| q = MaxHeapQ(num) | ||
| for k in iterator: | ||
| if key_ is not None: | ||
| k = (key_(k), k) | ||
| q.insert(k) | ||
| yield q.getElements() | ||
|
|
||
| def unKey(x, key_=None): | ||
| if key_ is not None: | ||
| x = [i[1] for i in x] | ||
| return x | ||
|
|
||
| def merge(a, b): | ||
| return next(topNKeyedElems(a + b)) | ||
| result = self.mapPartitions( | ||
| lambda i: topNKeyedElems(i, key)).reduce(merge) | ||
| return sorted(unKey(result, key), key=key) | ||
| return heapq.nsmallest(num, a + b, key) | ||
|
|
||
| return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) | ||
|
|
||
| def take(self, num): | ||
| """ | ||
|
|
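takeOrdered() follows the same shape with heapq.nsmallest, and the optional key is now passed straight through instead of the old (key(x), x) wrapping. For example, matching the docstring above:
>>> import heapq
>>> heapq.nsmallest(6, [10, 1, 2, 9, 3, 4, 5, 6, 7], key=lambda x: -x)
[10, 9, 7, 6, 5, 4]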
@@ -1016,13 +933,13 @@ def take(self, num): | |
| left = num - len(items) | ||
|
|
||
| def takeUpToNumLeft(iterator): | ||
| iterator = iter(iterator) | ||
| taken = 0 | ||
| while taken < left: | ||
| yield next(iterator) | ||
| taken += 1 | ||
|
|
||
| p = range( | ||
| partsScanned, min(partsScanned + numPartsToTry, totalParts)) | ||
| p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) | ||
| res = self.context.runJob(self, takeUpToNumLeft, p, True) | ||
|
|
||
| items += res | ||
|
|
@@ -1036,8 +953,15 @@ def first(self): | |
|
|
||
| >>> sc.parallelize([2, 3, 4]).first() | ||
| 2 | ||
| >>> sc.parallelize([]).first() | ||
| Traceback (most recent call last): | ||
| ... | ||
| ValueError: RDD is empty | ||
| """ | ||
| return self.take(1)[0] | ||
| rs = self.take(1) | ||
| if rs: | ||
| return rs[0] | ||
| raise ValueError("RDD is empty") | ||
|
|
||
| def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): | ||
| """ | ||
|
|
@@ -1262,13 +1186,13 @@ def reduceByKeyLocally(self, func): | |
| """ | ||
| def reducePartition(iterator): | ||
| m = {} | ||
| for (k, v) in iterator: | ||
| m[k] = v if k not in m else func(m[k], v) | ||
| for k, v in iterator: | ||
| m[k] = func(m[k], v) if k in m else v | ||
| yield m | ||
|
|
||
| def mergeMaps(m1, m2): | ||
| for (k, v) in m2.iteritems(): | ||
| m1[k] = v if k not in m1 else func(m1[k], v) | ||
| for k, v in m2.iteritems(): | ||
| m1[k] = func(m1[k], v) if k in m1 else v | ||
| return m1 | ||
| return self.mapPartitions(reducePartition).reduce(mergeMaps) | ||
|
|
||
|
|
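reduceByKeyLocally() builds one dict per partition and then folds the dicts together on the driver; the reordered conditionals only change readability (apply the combiner when the key is already present). A plain-Python sketch of the merge step (Python 2, matching the patch's iteritems):
>>> from operator import add
>>> def merge_maps(func, m1, m2):
...     for k, v in m2.iteritems():
...         m1[k] = func(m1[k], v) if k in m1 else v
...     return m1
>>> sorted(merge_maps(add, {'a': 1, 'b': 1}, {'a': 1, 'c': 2}).items())
[('a', 2), ('b', 1), ('c', 2)]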
@@ -1365,7 +1289,7 @@ def add_shuffle_key(split, iterator): | |
| buckets = defaultdict(list) | ||
| c, batch = 0, min(10 * numPartitions, 1000) | ||
|
|
||
| for (k, v) in iterator: | ||
| for k, v in iterator: | ||
| buckets[partitionFunc(k) % numPartitions].append((k, v)) | ||
| c += 1 | ||
|
|
||
|
|
@@ -1388,7 +1312,7 @@ def add_shuffle_key(split, iterator): | |
| batch = max(batch / 1.5, 1) | ||
| c = 0 | ||
|
|
||
| for (split, items) in buckets.iteritems(): | ||
| for split, items in buckets.iteritems(): | ||
| yield pack_long(split) | ||
| yield outputSerializer.dumps(items) | ||
|
|
||
|
|
@@ -1458,7 +1382,7 @@ def _mergeCombiners(iterator): | |
| merger.mergeCombiners(iterator) | ||
| return merger.iteritems() | ||
|
|
||
| return shuffled.mapPartitions(_mergeCombiners) | ||
| return shuffled.mapPartitions(_mergeCombiners, True) | ||
|
|
||
| def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): | ||
| """ | ||
|
|
@@ -1522,7 +1446,6 @@ def mergeCombiners(a, b): | |
| return self.combineByKey(createCombiner, mergeValue, mergeCombiners, | ||
| numPartitions).mapValues(lambda x: ResultIterable(x)) | ||
|
|
||
| # TODO: add tests | ||
| def flatMapValues(self, f): | ||
| """ | ||
| Pass each value in the key-value pair RDD through a flatMap function | ||
|
|
@@ -1612,9 +1535,9 @@ def subtractByKey(self, other, numPartitions=None): | |
| [('b', 4), ('b', 5)] | ||
| """ | ||
| def filter_func((key, vals)): | ||
| return len(vals[0]) > 0 and len(vals[1]) == 0 | ||
| return vals[0] and not vals[1] | ||
| map_func = lambda (key, vals): [(key, val) for val in vals[0]] | ||
|
Contributor
I think you can delete this line, now that it's unused.
||
| return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func) | ||
| return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) | ||
|
|
||
| def subtract(self, other, numPartitions=None): | ||
| """ | ||
|
|
@@ -1627,7 +1550,7 @@ def subtract(self, other, numPartitions=None): | |
| """ | ||
| # note: here 'True' is just a placeholder | ||
| rdd = other.map(lambda x: (x, True)) | ||
| return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) | ||
| return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys() | ||
|
|
||
| def keyBy(self, f): | ||
| """ | ||
|
|
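subtract() now forwards numPartitions to subtractByKey and finishes with keys() instead of mapping over tuples; the observable behaviour is unchanged. A usage sketch consistent with the existing docstring (assumes a live `sc`):
>>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)])
>>> y = sc.parallelize([("a", 3), ("c", None)])
>>> sorted(x.subtract(y).collect())
[('a', 1), ('b', 4), ('b', 5)]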
@@ -1720,9 +1643,8 @@ def name(self): | |
| Return the name of this RDD. | ||
| """ | ||
| name_ = self._jrdd.name() | ||
| if not name_: | ||
| return None | ||
| return name_.encode('utf-8') | ||
| if name_: | ||
| return name_.encode('utf-8') | ||
|
|
||
| def setName(self, name): | ||
| """ | ||
|
|
@@ -1740,9 +1662,8 @@ def toDebugString(self): | |
| A description of this RDD and its recursive dependencies for debugging. | ||
| """ | ||
| debug_string = self._jrdd.toDebugString() | ||
| if not debug_string: | ||
| return None | ||
| return debug_string.encode('utf-8') | ||
| if debug_string: | ||
| return debug_string.encode('utf-8') | ||
|
|
||
| def getStorageLevel(self): | ||
| """ | ||
|
|
@@ -1791,12 +1712,12 @@ def lookup(self, key): | |
| >>> sorted.lookup(42) # fast | ||
| [42] | ||
| """ | ||
| values = self.filter(lambda (k, v): k == key).values() | ||
| self = self.filter(lambda (k, v): k == key).values() | ||
|
Contributor
This reassignment to `self` …
||
|
|
||
| if hasattr(self, "_partitionFunc"): | ||
| if self._partitionFunc is not None: | ||
| return self.ctx.runJob(self, lambda x: x, [self._partitionFunc(key)], False) | ||
|
|
||
| return values.collect() | ||
| return self.collect() | ||
|
|
||
|
|
||
| class PipelinedRDD(RDD): | ||
|
|
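lookup(key) searches only the partition the key maps to when a partition function is known (for example after sortByKey), and falls back to filtering every partition otherwise. Usage, following the docstring this PR adds (assumes a live `sc`):
>>> l = range(1000)
>>> rdd = sc.parallelize(zip(l, l), 10)
>>> rdd.lookup(42)        # slow: every partition is searched
[42]
>>> sorted = rdd.sortByKey()
>>> sorted.lookup(42)     # fast: only the partition that can hold 42 is searched
[42]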
@@ -1842,6 +1763,7 @@ def pipeline_func(split, iterator): | |
| self._jrdd_val = None | ||
| self._jrdd_deserializer = self.ctx.serializer | ||
| self._bypass_serializer = False | ||
| self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None | ||
|
|
||
| @property | ||
| def _jrdd(self): | ||
|
|
||
Minor nit, but I'd drop the 'of' and just say "Cannot reduce() empty RDD"