From cbd58e3e752bf81d3c91f2752f8e6757f2dfabba Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Fri, 14 Nov 2014 00:39:52 -0800
Subject: [PATCH 1/2] specialize sc.parallelize(xrange)

---
 python/pyspark/context.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index faa5952258ae..0df979baa60c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -289,12 +289,31 @@ def stop(self):
 
     def parallelize(self, c, numSlices=None):
         """
-        Distribute a local Python collection to form an RDD.
-
-        >>> sc.parallelize(range(5), 5).glom().collect()
-        [[0], [1], [2], [3], [4]]
-        """
-        numSlices = numSlices or self.defaultParallelism
+        Distribute a local Python collection to form an RDD. Use xrange if
+        the input represents a range for performance.
+
+        >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
+        [[0], [2], [3], [4], [6]]
+        >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
+        [[], [0], [], [2], [4]]
+        """
+        numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
+        if isinstance(c, xrange):
+            size = len(c)
+            if size == 0:
+                return self.parallelize([], numSlices)
+            step = c[1] - c[0] if size > 1 else 1
+            c1 = xrange(c[0], c[0] + (size + 1) * step, step)
+
+            def getStartIndex(split):
+                return split * size / numSlices
+
+            def f(split, iterator):
+                startIndex = getStartIndex(split)
+                endIndex = getStartIndex(split + 1)
+                return xrange(c1[startIndex], c1[endIndex], step)
+
+            return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().

From 8953c415f025411e2159858e7800971eae2460f2 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Fri, 14 Nov 2014 09:41:45 -0800
Subject: [PATCH 2/2] follow davies' suggestion

---
 python/pyspark/context.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0df979baa60c..b6c991453d4d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -289,8 +289,8 @@ def stop(self):
 
     def parallelize(self, c, numSlices=None):
         """
-        Distribute a local Python collection to form an RDD. Use xrange if
-        the input represents a range for performance.
+        Distribute a local Python collection to form an RDD. Using xrange
+        is recommended if the input represents a range for performance.
 
         >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
         [[0], [2], [3], [4], [6]]
@@ -303,15 +303,13 @@ def parallelize(self, c, numSlices=None):
             if size == 0:
                 return self.parallelize([], numSlices)
             step = c[1] - c[0] if size > 1 else 1
-            c1 = xrange(c[0], c[0] + (size + 1) * step, step)
+            start0 = c[0]
 
-            def getStartIndex(split):
-                return split * size / numSlices
+            def getStart(split):
+                return start0 + (split * size / numSlices) * step
 
             def f(split, iterator):
-                startIndex = getStartIndex(split)
-                endIndex = getStartIndex(split + 1)
-                return xrange(c1[startIndex], c1[endIndex], step)
+                return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
         # Calling the Java parallelize() method with an ArrayList is too slow,
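
For reference, a minimal standalone sketch of the slicing arithmetic that the
second commit settles on, runnable without Spark. The helper name slice_xrange
is hypothetical (not part of the patch), and the sketch assumes Python 2
semantics, which the patch targets: xrange exists and `/` on ints floors.

    def slice_xrange(c, numSlices):
        """Slice an xrange into numSlices contiguous sub-ranges, mirroring
        the boundary math of getStart()/f() in the patch above."""
        size = len(c)
        if size == 0:
            return [xrange(0) for _ in range(numSlices)]
        step = c[1] - c[0] if size > 1 else 1
        start0 = c[0]

        def getStart(split):
            # Floor division spreads the size elements over numSlices
            # partitions; multiplying by step maps element indices back
            # to range values.
            return start0 + (split * size / numSlices) * step

        return [xrange(getStart(i), getStart(i + 1), step)
                for i in range(numSlices)]

This reproduces the doctest output in the patch:

    >>> [list(r) for r in slice_xrange(xrange(0, 6, 2), 5)]
    [[], [0], [], [2], [4]]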