diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d76ff76e675..7cbfb71beea3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
    * data collected from this job, the secret for authentication, and a socket auth
    * server object that can be used to join the JVM serving thread in Python.
    */
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
     val handleFunc = (sock: Socket) => {
       val out = new DataOutputStream(sock.getOutputStream)
       val in = new DataInputStream(sock.getInputStream)
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
-          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+          var result: Array[Any] = null
+          rdd.sparkContext.submitJob(
+            rdd,
+            (iter: Iterator[Any]) => iter.toArray,
+            Seq(i), // The partition we are evaluating
+            (_, res: Array[Any]) => result = res,
+            result)
         }
+        val prefetchIter = collectPartitionIter.buffered
 
         // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
@@ -196,10 +204,15 @@
           // Read request for data, value of zero will stop iteration or non-zero to continue
           if (in.readInt() == 0) {
             complete = true
-          } else if (collectPartitionIter.hasNext) {
+          } else if (prefetchIter.hasNext) {
 
             // Client requested more data, attempt to collect the next partition
-            val partitionArray = collectPartitionIter.next()
+            val partitionFuture = prefetchIter.next()
+            // Cause the next job to be submitted if prefetchPartitions is enabled.
+            if (prefetchPartitions) {
+              prefetchIter.headOption
+            }
+            val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
 
             // Send response there is a partition to read
             out.writeInt(1)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 590e8e1e9c07..17f159053900 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2444,17 +2444,23 @@ def countApproxDistinct(self, relativeSD=0.05):
         hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
         return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
 
-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Return an iterator that contains all of the elements in this RDD.
         The iterator will consume as much memory as the largest partition in this RDD.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.
 
         >>> rdd = sc.parallelize(range(10))
         >>> [x for x in rdd.toLocalIterator()]
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
         with SCCallSiteSync(self.context) as css:
-            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
+                self._jrdd.rdd(),
+                prefetchPartitions)
         return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
 
     def barrier(self):
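
The JVM-side mechanism above hinges on two details: `rdd.partitions.indices.iterator.map { ... }` is lazy, so each `submitJob` call happens only when the iterator is advanced, and `.buffered` exposes `headOption`, which forces the next future into existence before `ThreadUtils.awaitResult` blocks on the current one. A rough Python analogue of this one-job-ahead pattern, for illustration only (not Spark code; `collect_partition` is a hypothetical stand-in for a per-partition job):

    import time
    from concurrent.futures import ThreadPoolExecutor

    def collect_partition(i):
        # Hypothetical stand-in for running one Spark job per partition.
        time.sleep(1)
        return [i * 10 + n for n in range(3)]

    def local_iterator(num_partitions, prefetch):
        # Two workers: at most the current job plus one prefetched job in flight.
        with ThreadPoolExecutor(max_workers=2) as pool:
            # Lazy, like iterator.map over partition indices: nothing is
            # submitted until the generator is advanced.
            futures = (pool.submit(collect_partition, i)
                       for i in range(num_partitions))
            pending = next(futures, None)
            while pending is not None:
                # Peeking ahead (prefetchIter.headOption) submits the next
                # partition's job before we block on the current one.
                head = next(futures, None) if prefetch else None
                for row in pending.result():  # like ThreadUtils.awaitResult
                    yield row
                pending = head if prefetch else next(futures, None)

    for row in local_iterator(4, prefetch=True):
        print(row)

With `prefetch=False` the next job is created only after the current partition is fully consumed, which matches the old `runJob`-per-iteration behavior.
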
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 398471234d2b..03b37fa7d0d9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,16 +520,20 @@ def collect(self):
 
     @ignore_unicode_prefix
     @since(2.0)
-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Returns an iterator that contains all of the rows in this :class:`DataFrame`.
         The iterator will consume as much memory as the largest partition in this DataFrame.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.
 
         >>> list(df.toLocalIterator())
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.toPythonIterator()
+            sock_info = self._jdf.toPythonIterator(prefetchPartitions)
         return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 5550a093bf80..df3fb4da65f2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -689,6 +689,12 @@ def test_to_local_iterator(self):
         expected = df.collect()
         self.assertEqual(expected, list(it))
 
+    def test_to_local_iterator_prefetch(self):
+        df = self.spark.range(8, numPartitions=4)
+        expected = df.collect()
+        it = df.toLocalIterator(prefetchPartitions=True)
+        self.assertEqual(expected, list(it))
+
     def test_to_local_iterator_not_fully_consumed(self):
         # SPARK-23961: toLocalIterator throws exception when not fully consumed
         # Create a DataFrame large enough so that write to socket will eventually block
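
On the SQL side the flag is a straight pass-through to the same JVM helper, so DataFrame usage only gains the new keyword argument. A minimal sketch, assuming a running SparkSession named `spark` (the row and partition counts are arbitrary):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.range(1000, numPartitions=10)

    # Default: at most one partition is held on the driver, and each
    # partition's job starts only when the iterator reaches it.
    total = sum(row.id for row in df.toLocalIterator())

    # Prefetch: the next partition's job runs while the current one is
    # consumed, holding up to two partitions on the driver.
    total = sum(row.id for row in df.toLocalIterator(prefetchPartitions=True))
    print(total)
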
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index bff080362085..e7a7971dfc9a 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import datetime, timedelta
 import hashlib
 import os
 import random
 import sys
 import tempfile
+import time
 from glob import glob
 
 from py4j.protocol import Py4JJavaError
@@ -68,6 +70,26 @@ def test_to_localiterator(self):
         it2 = rdd2.toLocalIterator()
         self.assertEqual([1, 2, 3], sorted(it2))
 
+    def test_to_localiterator_prefetch(self):
+        # Test that we fetch the next partition in parallel
+        # We do this by returning the current time and:
+        # reading the first elem, waiting, and reading the second elem
+        # If not in parallel then these would be at different times
+        # But since they are being computed in parallel we see the time
+        # is "close enough" to the same.
+        rdd = self.sc.parallelize(range(2), 2)
+        times1 = rdd.map(lambda x: datetime.now())
+        times2 = rdd.map(lambda x: datetime.now())
+        times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
+        times_iter = times2.toLocalIterator(prefetchPartitions=False)
+        times_prefetch_head = next(times_iter_prefetch)
+        times_head = next(times_iter)
+        time.sleep(2)
+        times_next = next(times_iter)
+        times_prefetch_next = next(times_iter_prefetch)
+        self.assertTrue(times_next - times_head >= timedelta(seconds=2))
+        self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
+
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87f4c8f5d949..dc5670c78c2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3387,9 +3387,9 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def toPythonIterator(): Array[Any] = {
+  private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
     withNewExecutionId {
-      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
     }
   }
 
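
To make the time/memory tradeoff concrete, a hedged end-to-end sketch along the lines of the RDD test above (not part of the patch; assumes a fresh local SparkContext, and the sleeps exist only to make the overlap observable):

    import time

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "prefetch-demo")

    def slow_partition(it):
        # Roughly one second of compute per partition.
        time.sleep(1)
        return list(it)

    # Four partitions, one element each.
    rdd = sc.parallelize(range(4), 4).mapPartitions(slow_partition)

    def drain(prefetch):
        start = time.time()
        for _ in rdd.toLocalIterator(prefetchPartitions=prefetch):
            time.sleep(1)  # simulate per-element work on the client
        return time.time() - start

    # Compute and consumption alternate without prefetch (roughly 8s
    # here) but overlap with it (roughly 5s), at the cost of buffering
    # up to one extra partition on the driver.
    print("no prefetch: %.1fs" % drain(False))
    print("prefetch:    %.1fs" % drain(True))
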