21 changes: 17 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
   * data collected from this job, the secret for authentication, and a socket auth
   * server object that can be used to join the JVM serving thread in Python.
   */
  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
  def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
    val handleFunc = (sock: Socket) => {
      val out = new DataOutputStream(sock.getOutputStream)
      val in = new DataInputStream(sock.getInputStream)
      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
        // Collects a partition on each iteration
        val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
          var result: Array[Any] = null
          rdd.sparkContext.submitJob(
            rdd,
            (iter: Iterator[Any]) => iter.toArray,
            Seq(i), // The partition we are evaluating
            (_, res: Array[Any]) => result = res,
            result)
        }
        val prefetchIter = collectPartitionIter.buffered

        // Write data until iteration is complete, client stops iteration, or error occurs
        var complete = false
@@ -196,10 +204,15 @@
          // Read request for data, value of zero will stop iteration or non-zero to continue
          if (in.readInt() == 0) {
            complete = true
          } else if (collectPartitionIter.hasNext) {
          } else if (prefetchIter.hasNext) {

            // Client requested more data, attempt to collect the next partition
            val partitionArray = collectPartitionIter.next()
            val partitionFuture = prefetchIter.next()
            // Cause the next job to be submitted if prefecthPartitions is enabled.
Member: typo: prefecthPartitions -> prefetchPartitions

            if (prefetchPartitions) {
              prefetchIter.headOption
            }
            val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
Member: It might be best to avoid awaitResult if possible. Could you make a buffered iterator yourself? Maybe something like:

    var next = collectPartitionIter.next()
    val prefetchIter = collectPartitionIter.map { part =>
      val tmp = next
      next = part
      tmp
    } ++ Iterator(next)

Contributor Author: The awaitResult (or something similar) is required for us to use futures. If we just used a buffered iterator without letting the jobs be scheduled separately, we'd block for both partitions right away instead of evaluating the other future in the background while we block on the first. (Implicitly, this awaitResult was already effectively done inside the previous DAGScheduler runJob call.)

Member: Ah yes, you are totally right. That would block while getting the prefetched partition. This looks pretty good to me then.

One question though: when should the first job be triggered? I think the old behavior used to start the first job as soon as toLocalIterator() was called. From what I can tell, this will wait until the first iteration and then trigger the first two jobs. Either way is probably fine, but you might get slightly better performance by starting the first job immediately.

Contributor Author: In either case it waits to read a request for data from the Python side before starting a job, because the map on the partition indices is lazily evaluated.

            // Send response there is a partition to read
            out.writeInt(1)
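A conceptual rendering of the mechanism above, for readers skimming the diff: the partition-collection iterator lazily submits one job per partition, and when prefetchPartitions is enabled the buffered iterator's headOption kicks off the next job before the current future is awaited, so the next partition is computed while the current one is streamed to Python. The sketch below re-creates that pattern in plain Python with concurrent.futures; it is illustrative only (all names are made up, and it is not Spark internals). It also shows why no job starts until the first request arrives: the generator that submits jobs is only advanced on demand, matching the lazy map over partition indices discussed in the review thread above.

    # Conceptual sketch only (plain Python, not Spark internals); all names are
    # made up. Jobs are submitted lazily, one per partition, and with prefetch
    # enabled the next job is started before we block on the current result.
    import time
    from concurrent.futures import ThreadPoolExecutor

    def compute_partition(i):
        time.sleep(1)  # stand-in for a Spark job collecting partition i
        return [i * 10, i * 10 + 1]

    def partition_futures(pool, num_partitions):
        # Lazily submit one "job" per partition; nothing runs until next() is
        # called, mirroring the lazy map over rdd.partitions.indices.iterator.
        for i in range(num_partitions):
            yield pool.submit(compute_partition, i)

    def to_local_iterator(num_partitions, prefetch=False):
        with ThreadPoolExecutor(max_workers=2) as pool:
            futures = partition_futures(pool, num_partitions)
            pending = next(futures, None)
            while pending is not None:
                # With prefetch, start the next job before blocking on this one
                # (the equivalent of prefetchIter.headOption before awaitResult).
                upcoming = next(futures, None) if prefetch else None
                for elem in pending.result():  # block, like awaitResult
                    yield elem
                pending = upcoming if prefetch else next(futures, None)

    print(list(to_local_iterator(4, prefetch=True)))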
86 changes: 86 additions & 0 deletions examples/src/main/python/prefetch.py
@@ -0,0 +1,86 @@
#
Member (@HyukjinKwon, Aug 23, 2019): I think examples in this directory are meant to show how a feature or API is used rather than to show perf results; that can just go in the PR description. In practice the example boils down to .toLocalIterator(prefetchPartitions=True), which I don't think is worth a separate example file.

Contributor Author: Reasonable, I'll remove it from the examples; it was mostly a simple way to share the microbenchmark.

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

import sys
import timeit
from random import random
from operator import add

from pyspark.sql import SparkSession


if __name__ == "__main__":
"""
Usage: prefetch [partitions] [iterations]
Uses timeit to demonstrate the benefit of prefetch.
"""
spark = SparkSession\
.builder\
.appName("PrefetchDemo")\
.getOrCreate()

partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 4
iterations = int(sys.argv[2]) if len(sys.argv) > 2 else 4

elems = spark.sparkContext.parallelize(range(1, partitions * 2), partitions)
elems.cache()
elems.count()

    def slowCompute(partition):
        """
        Wait ten seconds to simulate slow partition computation, then return
        the partition (an iterator over its elements) unchanged.
        """
        import time
        time.sleep(10)
        return partition

    def localCompute(elem):
        """
        Simulate processing the data locally.
        """
        import time
        time.sleep(1)
        return elem

    def fetchWithPrefetch():
        prefetchIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=True)
        localCollection = list(map(localCompute, prefetchIter))
        return localCollection

    def fetchRegular():
        regularIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=False)
        localCollection = list(map(localCompute, regularIter))
        return localCollection

print("Running timers:\n")
prefetchTimer = timeit.Timer(fetchWithPrefetch)
prefetchTime = prefetchTimer.timeit(number=iterations)

regularTimer = timeit.Timer(fetchRegular)
regularTime = regularTimer.timeit(number=iterations)
print("\nResults:\n")
print("Prefetch time:\n")
print(prefetchTime)
print("\n")

print("Regular time:\n")
print(regularTime)
print("\n")

spark.stop()
9 changes: 7 additions & 2 deletions python/pyspark/rdd.py
@@ -2444,17 +2444,22 @@ def countApproxDistinct(self, relativeSD=0.05):
        hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
        return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)

    def toLocalIterator(self):
    def toLocalIterator(self, prefetchPartitions=False):
        """
        Return an iterator that contains all of the elements in this RDD.
        The iterator will consume as much memory as the largest partition in this RDD.
        With prefetch it may consume up to the memory of the 2 largest partitions.

        :param prefetchPartitions: If Spark should pre-fetch the next partition
                                   before it is needed.
        >>> rdd = sc.parallelize(range(10))
        >>> [x for x in rdd.toLocalIterator()]
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        """
        with SCCallSiteSync(self.context) as css:
            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
                self._jrdd.rdd(),
                prefetchPartitions)
        return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)

    def barrier(self):
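A minimal usage sketch of the new flag, assuming an existing SparkContext `sc`; the per-element work is simulated. Prefetch pays off when the driver does non-trivial work per element, because the next partition's job runs while the current partition is being consumed:

    # Usage sketch (assumes a live SparkContext `sc`); driver-side work is simulated.
    def process_locally(x):
        return x * 2  # stand-in for real per-element work on the driver

    rdd = sc.parallelize(range(1000), 8)
    results = [process_locally(x) for x in rdd.toLocalIterator(prefetchPartitions=True)]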
5 changes: 4 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -520,11 +520,14 @@ def collect(self):

    @ignore_unicode_prefix
    @since(2.0)
    def toLocalIterator(self):
    def toLocalIterator(self, prefetchPartitions=False):
        """
        Returns an iterator that contains all of the rows in this :class:`DataFrame`.
        The iterator will consume as much memory as the largest partition in this DataFrame.
        With prefetch it may consume up to the memory of the 2 largest partitions.

        :param prefetchPartitions: If Spark should pre-fetch the next partition
                                   before it is needed.
        >>> list(df.toLocalIterator())
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
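The DataFrame side works the same way; a short hypothetical sketch (assuming a live SparkSession `spark`), streaming rows to the driver one partition at a time instead of collecting them all:

    # Usage sketch for the DataFrame API (assumes `spark` exists).
    df = spark.range(10 ** 6, numPartitions=16)
    total = 0
    for row in df.toLocalIterator(prefetchPartitions=True):
        total += row.id  # row-by-row work on the driver, roughly one partition in memory at a time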
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -689,6 +689,12 @@ def test_to_local_iterator(self):
        expected = df.collect()
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_prefetch(self):
        df = self.spark.range(8, numPartitions=4)
        expected = df.collect()
        it = df.toLocalIterator(prefetchPartitions=True)
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_not_fully_consumed(self):
        # SPARK-23961: toLocalIterator throws exception when not fully consumed
        # Create a DataFrame large enough so that write to socket will eventually block
23 changes: 23 additions & 0 deletions python/pyspark/tests/test_rdd.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime, timedelta
import hashlib
import os
import random
import sys
import tempfile
import time
from glob import glob

from py4j.protocol import Py4JJavaError
@@ -68,6 +70,27 @@ def test_to_localiterator(self):
        it2 = rdd2.toLocalIterator()
        self.assertEqual([1, 2, 3], sorted(it2))

    def test_to_localiterator_prefetch(self):
        # Test that we fetch the next partition in parallel.
        # We do this by returning the current time, reading the first elem,
        # waiting, and then reading the second elem. If the partitions were not
        # computed in parallel these times would be far apart, but since they
        # are computed in parallel the times are "close enough" to the same.
        rdd = self.sc.parallelize(range(2), 2)
        times1 = rdd.map(lambda x: datetime.now())
        times2 = rdd.map(lambda x: datetime.now())
        timesIterPrefetch = times1.toLocalIterator(prefetchPartitions=True)
Member: Shall we stick to the underscore naming rule?

        timesIter = times2.toLocalIterator(prefetchPartitions=False)
        timesPrefetchHead = next(timesIterPrefetch)
        timesHead = next(timesIter)
        time.sleep(2)
        timesNext = next(timesIter)
        timesPrefetchNext = next(timesIterPrefetch)
        print("With prefetch times are: " + str(timesPrefetchHead) + "," + str(timesPrefetchNext))
Member: Shall we remove the print?

        self.assertTrue(timesNext - timesHead >= timedelta(seconds=2))
        self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=1))
Member: This is a pretty clever test! Anything with timings makes me a bit worried about flakiness, but I don't have any other idea how to test this. Is it possible to see if the jobs were scheduled?

Contributor Author: I think we could if we used a fresh SparkContext, but with the reused context I'm not sure how I'd know whether the job was run or not.

    def test_save_as_textfile_with_unicode(self):
        # Regression test for SPARK-970
        x = u"\u00A1Hola, mundo!"
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3387,9 +3387,9 @@
    }
  }

  private[sql] def toPythonIterator(): Array[Any] = {
  private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
    withNewExecutionId {
      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
    }
  }
