Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,52 @@ def batch_as(rdd, batchSize):
other._jrdd_deserializer)
return RDD(pairRDD, self.ctx, deserializer)

def zipWithIndex(self):
"""
Zips this RDD with its element indices.

The ordering is first based on the partition index and then the
ordering of items within each partition. So the first item in
the first partition gets index 0, and the last item in the last
partition receives the largest index.

This method needs to trigger a spark job when this RDD contains
more than one partitions.

>>> sc.parallelize(range(4), 2).zipWithIndex().collect()
[(0, 0), (1, 1), (2, 2), (3, 3)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't the best example because it's not clear which element is the item and which element is its index. In the Scala API, this is clear from the method's return type. Maybe we should update the documentation to explicitly state that the second element is the id (like the Scala API).

I think this implementation has things backwards w.r.t. the Scala one:

>>> sc.parallelize(['a', 'b', 'c', 'd'], 2).zipWithIndex().collect()
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'd')]

versus

scala> sc.parallelize(Seq('a', 'b', 'c', 'd')).zipWithIndex().collect()
res0: Array[(Char, Long)] = Array((a,0), (b,1), (c,2), (d,3))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change it.

"""
starts = [0]
if self.getNumPartitions() > 1:
nums = self.mapPartitions(lambda it: [sum(1 for i in it)]).collect()
for i in range(len(nums) - 1):
starts.append(starts[-1] + nums[i])

def func(k, it):
return enumerate(it, starts[k])

return self.mapPartitionsWithIndex(func)

def zipWithUniqueId(self):
"""
Zips this RDD with generated unique Long ids.

Items in the kth partition will get ids k, n+k, 2*n+k, ..., where
n is the number of partitions. So there may exist gaps, but this
method won't trigger a spark job, which is different from
L{zipWithIndex}

>>> sc.parallelize(range(4), 2).zipWithUniqueId().collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, it might be better to use three partitions (or some other value) so that there's a gap in the ids.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea.

[(0, 0), (2, 1), (1, 2), (3, 3)]
"""
n = self.getNumPartitions()

def func(k, it):
for i, v in enumerate(it):
yield i * n + k, v

return self.mapPartitionsWithIndex(func)

def name(self):
"""
Return the name of this RDD.
Expand Down