[SPARK-2871] [PySpark] Add missing API #1791
Changes from 6 commits
python/pyspark/context.py
@@ -259,6 +259,17 @@ def defaultMinPartitions(self):
         """
         return self._jsc.sc().defaultMinPartitions()
 
+    @property
+    def isLocal(self):
+        """
+        Whether the context run locally
+        """
+        return self._jsc.isLocal()
+
+    @property
+    def conf(self):
Contributor: This needs a docstring. Also, the Scala equivalent of this clones the SparkConf because it cannot be changed at runtime. We might want to do the same thing here (to guard against misuse); I'm not sure how clone() interacts with Py4J objects; do we need to implement a custom clone method for objects with Py4J objects as fields that calls those objects' JVM clone methods?
+        return self._conf
Contributor: I agree with Josh here, you need to clone the conf before returning it.
Contributor (Author): I will return a read-only copy of it.
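For reference, a detached copy could be produced along these lines (a sketch only, not the change made in this PR; it assumes the usual pyspark.SparkConf set/getAll API, and _copy_conf is an invented helper name):

```python
from pyspark import SparkConf

def _copy_conf(live_conf):
    # Hypothetical helper: copy every entry into a fresh SparkConf so that
    # mutating the returned object cannot affect the running context.
    copied = SparkConf()
    for key, value in live_conf.getAll():
        copied.set(key, value)
    return copied
```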
 
     def stop(self):
         """
         Shut down the SparkContext.
@@ -724,6 +735,13 @@ def sparkUser(self):
         """
         return self._jsc.sc().sparkUser()
 
+    @property
+    def startTime(self):
+        """
+        Return the start time of context in millis seconds
Contributor: This
Contributor (Author): I saw it in the Java API docs, so I added it here.
Contributor: The primary use of this, outside of SparkContext, seems to be printing the context's uptime. So, why not add an
Contributor (Author): Changing it to uptime would not improve anything. Or should we just remove it?
+        """
+        return self._jsc.startTime()
+
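Given the uptime suggestion in the thread above, a thin helper on top of startTime might look like this (hypothetical sketch; it assumes startTime is milliseconds since the epoch, as documented for the Java API):

```python
import time

def uptime_seconds(sc):
    # Hypothetical helper: wall-clock seconds since the given SparkContext started.
    return time.time() - sc.startTime / 1000.0
```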
     def cancelJobGroup(self, groupId):
         """
         Cancel active jobs for the specified group. See L{SparkContext.setJobGroup}
@@ -763,6 +781,15 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
         return list(mappedRDD._collect_iterator_through_file(it))
 
+    def runApproximateJob(self, rdd, func, evaluator, timeout):
+        """
+        :: DeveloperApi ::
+        Run a job that can return approximate results.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
 
 def _test():
     import atexit
python/pyspark/rdd.py
@@ -32,6 +32,7 @@
 import heapq
 from random import Random
 from math import sqrt, log
+import array
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -734,6 +735,19 @@ def _collect_iterator_through_file(self, iterator):
                 yield item
         os.unlink(tempFile.name)
 
+    def collectPartitions(self, partitions):
Contributor: In the Scala API, this is marked as a private API used only for tests. Is there a non-test usecase for this?
Contributor (Author): It will help for debugging: you can collect parts of the RDD to investigate them. It would also be helpful if we had an API called slice(start, [end]) to select parts of the partitions. DPark has this kind of API and it helps us a lot to narrow down the data for fast debugging.
Contributor: I agree with Josh, let's delete this for now. We can open a separate JIRA about making it public and maybe discuss there.
Contributor: BTW I do like a slice-based API in general, that might be what we propose publicly.
+        """
+        Return a list of list that contains all of the elements in a specific
+        partition of this RDD.
+
+        >>> rdd = sc.parallelize(range(8), 4)
+        >>> rdd.collectPartitions([1, 3])
+        [[2, 3], [6, 7]]
+        """
+        return [self.ctx.runJob(self, lambda it: it, [p], True)
+                for p in partitions]
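The slice-style selection discussed in the thread above could be layered on top of collectPartitions roughly as follows (a hypothetical sketch, not part of this PR; collect_partition_slice is an invented name):

```python
def collect_partition_slice(rdd, start, end=None):
    # Hypothetical wrapper: collect only partitions start..end-1 and flatten
    # the per-partition lists returned by collectPartitions into one list.
    end = rdd.getNumPartitions() if end is None else end
    parts = rdd.collectPartitions(range(start, end))
    return [x for part in parts for x in part]

# e.g. collect_partition_slice(sc.parallelize(range(8), 4), 1, 3) -> [2, 3, 4, 5]
```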
+
     def reduce(self, f):
         """
         Reduces the elements of this RDD using the specified commutative and
@@ -808,23 +822,39 @@ def func(iterator):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
-    def max(self):
+    def max(self, comp=None):
Contributor: Maybe explain what "comp" is in the doc comment.
         """
         Find the maximum item in this RDD.
 
-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
+        >>> rdd.max()
         43.0
+        >>> rdd.max(lambda a, b: cmp(str(a), str(b)))
+        5.0
         """
-        return self.reduce(max)
+        if comp is not None:
+            func = lambda a, b: a if comp(a, b) >= 0 else b
+        else:
+            func = max
 
-    def min(self):
+        return self.reduce(func)
+
+    def min(self, comp=None):
| """ | ||
| Find the minimum item in this RDD. | ||
|
|
||
| >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min() | ||
| >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) | ||
| >>> rdd.min() | ||
| 1.0 | ||
| >>> rdd.min(lambda a, b: cmp(str(a), str(b))) | ||
|
Contributor: In
+        1.0
         """
-        return self.reduce(min)
+        if comp is not None:
+            func = lambda a, b: a if comp(a, b) <= 0 else b
+        else:
+            func = min
 
+        return self.reduce(func)
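To see what the comparator-based reduce above computes, the same selection can be traced with plain Python 2 (illustration only; cmp and reduce are the built-ins used by the doctests):

```python
comp = lambda a, b: cmp(str(a), str(b))              # compare items by their string form
pick_max = lambda a, b: a if comp(a, b) >= 0 else b
pick_min = lambda a, b: a if comp(a, b) <= 0 else b

data = [1.0, 5.0, 43.0, 10.0]
print reduce(pick_max, data)   # 5.0  ("5.0" sorts after "43.0" and "10.0")
print reduce(pick_min, data)   # 1.0  ("1.0" sorts before the other strings)
```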
 
     def sum(self):
         """
@@ -854,6 +884,59 @@ def redFunc(left_counter, right_counter):
 
         return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
 
+    def histogram(self, buckets=None, even=False):
+        """
+        Compute a histogram of the data.
+
+        Compute a histogram using the provided buckets. The buckets
+        are all open to the left except for the last which is closed
+        e.g. for the array [1,10,20,50] the buckets are [1,10) [10,20)
+        [20,50] e.g 1<=x<10, 10<=x<20, 20<=x<50. And on the input of 1
+        and 50 we would have a histogram of 1,0,0.
+
+        Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30])
+        this can be switched from an O(log n) inseration to O(1) per
+        element(where n = # buckets), if you set evenBuckets to true.
+        Buckets must be sorted and not contain any duplicates, Buckets
+        array must be at least two elements All NaN entries are treated
+        the same. If you have a NaN bucket it must be the maximum value
+        of the last position and all NaN entries will be counted in that
+        bucket.
+
+        If buckets is a number, it will generates buckets which is
+        evenly spaced between the minimum and maximum of the RDD. For
+        example, if the min value is 0 and the max is 100, given buckets
+        as 2, the resulting buckets will be [0,50) [50,100]. buckets must
+        be at least 1 If the RDD contains infinity, NaN throws an exception
+        If the elements in RDD do not vary (max == min) always returns
+        a single bucket. It will return an tuple of buckets and histogram
+        in them.
+
+        >>> rdd = sc.parallelize(range(51))
+        >>> rdd.histogram(2)
+        ([0.0, 25.0, 50.0], [25L, 26L])
+        >>> rdd.histogram([0, 5, 25, 50])
+        [5L, 20L, 26L]
+        >>> rdd.histogram([0, 15, 30, 45, 60], True)
+        [15L, 15L, 15L, 6L]
+        """
+
+        drdd = self.map(lambda x:float(x))
+        batched = isinstance(drdd._jrdd_deserializer, BatchedSerializer)
+        jdrdd = self.ctx._jvm.PythonRDD.pythonToJavaDouble(drdd._jrdd, batched)
+
+        if isinstance(buckets, (int,long)):
+            if buckets < 1:
+                raise ValueError("buckets should be greater than 1")
+
+            r = jdrdd.histogram(buckets)
+            return list(r._1()), list(r._2())
+
+        jbuckets = self.ctx._gateway.new_array(self.ctx._gateway.jvm.java.lang.Double, len(buckets))
+        for i in range(len(buckets)):
+            jbuckets[i] = float(buckets[i])
+        return list(jdrdd.histogram(jbuckets, even))
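The bucket semantics described in the docstring can be checked against a small pure-Python reference (illustration only, not how the JVM side computes it): every bucket is [b[i], b[i+1]) except the last, which also includes its upper bound.

```python
import bisect

def histogram_reference(values, buckets):
    # Count values into the buckets using the boundary rules described above.
    counts = [0] * (len(buckets) - 1)
    for v in values:
        if v == buckets[-1]:
            counts[-1] += 1                      # last bucket is closed on the right
        elif buckets[0] <= v < buckets[-1]:
            counts[bisect.bisect_right(buckets, v) - 1] += 1
    return counts

print histogram_reference(range(51), [0, 5, 25, 50])   # [5, 20, 26], matching the doctest
```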
 
     def mean(self):
         """
         Compute the mean of this RDD's elements.
@@ -872,6 +955,7 @@ def variance(self):
         """
         return self.stats().variance()
 
+
     def stdev(self):
         """
         Compute the standard deviation of this RDD's elements.
@@ -1673,11 +1757,57 @@ def zip(self, other):
         >>> x.zip(y).collect()
         [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
         """
+        if self.getNumPartitions() != other.getNumPartitions():
+            raise ValueError("the number of partitions dose not match"
+                             " with each other")
+
         pairRDD = self._jrdd.zip(other._jrdd)
         deserializer = PairDeserializer(self._jrdd_deserializer,
                                         other._jrdd_deserializer)
         return RDD(pairRDD, self.ctx, deserializer)
 
+    def zipPartitions(self, other, f, preservesPartitioning=False):
+        """
+        Zip this RDD's partitions with one (or more) RDD(s) and return a
+        new RDD by applying a function to the zipped partitions.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def zipWithIndex(self):
+        """
+        Zips this RDD with its element indices.
+
+        >>> sc.parallelize(range(4), 2).zipWithIndex().collect()
+        [(0, 0), (1, 1), (2, 2), (3, 3)]
+        """
+        nums = self.glom().map(lambda it: sum(1 for i in it)).collect()
+        starts = [0]
+        for i in range(len(nums) - 1):
+            starts.append(starts[-1] + nums[i])
+
+        def func(k, it):
+            for i, v in enumerate(it):
+                yield starts[k] + i, v
+
+        return self.mapPartitionsWithIndex(func)
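The start-offset bookkeeping above can be traced by hand for the doctest's two-partition example (illustration only):

```python
# range(4) split across 2 partitions: [0, 1] and [2, 3]
nums = [2, 2]                      # elements per partition (the glom()/count step)
starts = [0]
for i in range(len(nums) - 1):
    starts.append(starts[-1] + nums[i])
print starts                       # [0, 2]: partition k numbers its elements from starts[k]
```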
+
+    def zipWithUniqueId(self):
+        """
+        Zips this RDD with generated unique Long ids.
+
+        >>> sc.parallelize(range(4), 2).zipWithUniqueId().collect()
+        [(0, 0), (2, 1), (1, 2), (3, 3)]
+        """
+        n = self.getNumPartitions()
+
+        def func(k, it):
+            for i, v in enumerate(it):
+                yield i * n + k, v
+
+        return self.mapPartitionsWithIndex(func)
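Likewise, the i * n + k id scheme can be checked against the doctest output (illustration only):

```python
n = 2                                # number of partitions
parts = [(0, [0, 1]), (1, [2, 3])]   # (partition index k, its elements) for range(4)
ids = [(i * n + k, v) for k, vs in parts for i, v in enumerate(vs)]
print ids                            # [(0, 0), (2, 1), (1, 2), (3, 3)], as in the doctest
```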
 
     def name(self):
         """
         Return the name of this RDD.
@@ -1743,6 +1873,79 @@ def _defaultReducePartitions(self):
     # on the key; we need to compare the hash of the key to the hash of the
     # keys in the pairs. This could be an expensive operation, since those
     # hashes aren't retained.
+    def lookup(self, key):
+        """
+        Return the list of values in the RDD for key key.
+
+        Not Implemented
+        """
+        raise NotImplementedError
+
+    def countApprox(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Approximate version of count() that returns a potentially incomplete
+        result within a timeout, even if not all tasks have finished.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def countApproxDistinct(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Return approximate number of distinct elements in the RDD.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def countByValueApprox(self, timeout, confidence=0.95):
+        """
+        :: Experimental::
+        Approximate version of countByValue().
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def sumApprox(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Approximate operation to return the sum within a timeout
+        or meet the confidence.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def meanApprox(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Approximate operation to return the mean within a timeout
+        or meet the confidence.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def countApproxDistinctByKey(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Return approximate number of distinct values for each key in this RDD.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
+    def countByKeyApprox(self, timeout, confidence=0.95):
+        """
+        :: Experimental ::
+        Approximate version of countByKey that can return a partial result if it does not finish within a timeout.
+
+        Not implemented.
+        """
+        raise NotImplementedError
+
 
 class PipelinedRDD(RDD):
Contributor: "Convert a RDD of Java objects to and RDD of serialized Python objects" => "Convert an RDD of Java objects to an RDD of serialized Python objects"?