21 changes: 17 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
var result: Array[Any] = null
rdd.sparkContext.submitJob(
rdd,
(iter: Iterator[Any]) => iter.toArray,
Seq(i), // The partition we are evaluating
(_, res: Array[Any]) => result = res,
result)
}
val prefetchIter = collectPartitionIter.buffered

// Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
@@ -196,10 +204,15 @@
// Read request for data, value of zero will stop iteration or non-zero to continue
if (in.readInt() == 0) {
complete = true
} else if (collectPartitionIter.hasNext) {
} else if (prefetchIter.hasNext) {

// Client requested more data, attempt to collect the next partition
val partitionArray = collectPartitionIter.next()
val partitionFuture = prefetchIter.next()
// Cause the next job to be submitted if prefetchPartitions is enabled.
if (prefetchPartitions) {
prefetchIter.headOption
}
val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
Member:
It might be best to avoid awaitResult if possible. Could you make a buffered iterator yourself? Maybe something like:

var next = collectPartitionIter.next()
val prefetchIter = collectPartitionIter.map { part =>
  val tmp = next
  next = part
  tmp
} ++ Iterator(next)

Contributor Author:
So the awaitResult (or something similar) is required for us to use futures. If we just used a buffered iterator without allowing the job to be scheduled separately, we'd block for both partitions right away instead of evaluating the other future in the background while we block on the first. (Implicitly, this awaitResult is already effectively done inside the previous DAGScheduler runJob.)
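
A minimal, Spark-free sketch of the contrast being described (all names here are illustrative, not from the PR): wrapping each blocking "collect" in a Future lets peeking at the buffered head start the next partition in the background, while awaiting blocks only on the partition currently being served.

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object PrefetchSketch extends App {
  // Stand-in for collecting one partition: takes about a second.
  def collectPartition(i: Int): Int = { Thread.sleep(1000); i }

  // Lazily wrap each partition in a Future; nothing starts until next() or head is called.
  val futures = Iterator.range(0, 3).map(i => Future(collectPartition(i))).buffered

  val results = futures.map { f =>
    futures.headOption                    // prefetch: start computing the next partition now
    val r = Await.result(f, Duration.Inf) // block only on the partition being served
    Thread.sleep(1000)                    // consumer processes it; the next partition finishes meanwhile
    r
  }
  // Roughly 4 seconds with the prefetch line above, roughly 6 seconds without it.
  println(results.toList)
}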

Member:
Ah yes, you are totally right. That would block while getting the prefetched partition. This looks pretty good to me then.

One question, though: when should the first job be triggered? I think the old behavior used to start the first job as soon as toLocalIterator() was called. From what I can tell, this will wait until the first iteration and then trigger the first two jobs. Either way is probably fine, but you might get slightly better performance by starting the first job immediately.

Contributor Author:
In either case it waits for a data request from the Python side before starting a job, because the map over the partition indices is lazily evaluated.
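
A tiny illustrative sketch of that laziness (nothing here is Spark-specific): the map body runs only when the iterator is advanced or peeked, which is why no job is submitted until the first read request arrives.

// Iterator.map is lazy: the body runs only when an element is pulled or peeked.
val jobs = Iterator.range(0, 3).map { i =>
  println(s"submitting job for partition $i") // stands in for sparkContext.submitJob
  i
}.buffered
// Nothing has been printed at this point.
jobs.headOption // now prints "submitting job for partition 0"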


// Send response there is a partition to read
out.writeInt(1)
10 changes: 8 additions & 2 deletions python/pyspark/rdd.py
@@ -2444,17 +2444,23 @@ def countApproxDistinct(self, relativeSD=0.05):
hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)

def toLocalIterator(self):
def toLocalIterator(self, prefetchPartitions=False):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
With prefetch it may consume up to the memory of the 2 largest partitions.

:param prefetchPartitions: If Spark should pre-fetch the next partition
before it is needed.

>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
with SCCallSiteSync(self.context) as css:
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
self._jrdd.rdd(),
prefetchPartitions)
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)

def barrier(self):
8 changes: 6 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -520,16 +520,20 @@ def collect(self):

@ignore_unicode_prefix
@since(2.0)
def toLocalIterator(self):
def toLocalIterator(self, prefetchPartitions=False):
"""
Returns an iterator that contains all of the rows in this :class:`DataFrame`.
The iterator will consume as much memory as the largest partition in this DataFrame.
With prefetch it may consume up to the memory of the 2 largest partitions.

:param prefetchPartitions: If Spark should pre-fetch the next partition
before it is needed.

>>> list(df.toLocalIterator())
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.toPythonIterator()
sock_info = self._jdf.toPythonIterator(prefetchPartitions)
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

@ignore_unicode_prefix
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -689,6 +689,12 @@ def test_to_local_iterator(self):
expected = df.collect()
self.assertEqual(expected, list(it))

def test_to_local_iterator_prefetch(self):
df = self.spark.range(8, numPartitions=4)
expected = df.collect()
it = df.toLocalIterator(prefetchPartitions=True)
self.assertEqual(expected, list(it))

def test_to_local_iterator_not_fully_consumed(self):
# SPARK-23961: toLocalIterator throws exception when not fully consumed
# Create a DataFrame large enough so that write to socket will eventually block
22 changes: 22 additions & 0 deletions python/pyspark/tests/test_rdd.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime, timedelta
import hashlib
import os
import random
import sys
import tempfile
import time
from glob import glob

from py4j.protocol import Py4JJavaError
@@ -68,6 +70,26 @@ def test_to_localiterator(self):
it2 = rdd2.toLocalIterator()
self.assertEqual([1, 2, 3], sorted(it2))

def test_to_localiterator_prefetch(self):
# Test that we fetch the next partition in parallel
# We do this by returning the current time and:
# reading the first elem, waiting, and reading the second elem
# If not in parallel then these would be at different times
# But since they are being computed in parallel we see the time
# is "close enough" to the same.
rdd = self.sc.parallelize(range(2), 2)
times1 = rdd.map(lambda x: datetime.now())
times2 = rdd.map(lambda x: datetime.now())
times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
times_iter = times2.toLocalIterator(prefetchPartitions=False)
times_prefetch_head = next(times_iter_prefetch)
times_head = next(times_iter)
time.sleep(2)
times_next = next(times_iter)
times_prefetch_next = next(times_iter_prefetch)
self.assertTrue(times_next - times_head >= timedelta(seconds=2))
self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))

def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3387,9 +3387,9 @@ class Dataset[T] private[sql](
}
}

private[sql] def toPythonIterator(): Array[Any] = {
private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
}
}
