[SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator #25515
Changes from 7 commits
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

```diff
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
 
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
    * data collected from this job, the secret for authentication, and a socket auth
    * server object that can be used to join the JVM serving thread in Python.
    */
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
     val handleFunc = (sock: Socket) => {
       val out = new DataOutputStream(sock.getOutputStream)
       val in = new DataInputStream(sock.getInputStream)
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
-          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+          var result: Array[Any] = null
+          rdd.sparkContext.submitJob(
+            rdd,
+            (iter: Iterator[Any]) => iter.toArray,
+            Seq(i), // The partition we are evaluating
+            (_, res: Array[Any]) => result = res,
+            result)
         }
+        val prefetchIter = collectPartitionIter.buffered
 
         // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
 
@@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
           // Read request for data, value of zero will stop iteration or non-zero to continue
           if (in.readInt() == 0) {
             complete = true
-          } else if (collectPartitionIter.hasNext) {
+          } else if (prefetchIter.hasNext) {
 
             // Client requested more data, attempt to collect the next partition
-            val partitionArray = collectPartitionIter.next()
+            val partitionFuture = prefetchIter.next()
+            // Cause the next job to be submitted if prefecthPartitions is enabled.
+            if (prefetchPartitions) {
+              prefetchIter.headOption
+            }
+            val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
```
Member
It might be best to avoid this; something like:

```scala
var next = collectPartitionIter.next()
val prefetchIter = collectPartitionIter.map { part =>
  val tmp = next
  next = part
  tmp
} ++ Iterator(next)
```
Contributor (Author)
So the `awaitResult` (or something similar) is required for us to use futures. If we just used a buffered iterator without allowing the job to schedule separately, we'd just block for both partitions right away instead of evaluating the other future in the background while we block on the first. (Implicitly, this awaitResult is already effectively done inside the previous DAGScheduler's runJob.)
Member
Ah yes, you are totally right. That would block while getting the prefetched partition. This looks pretty good to me then. One question though: when should the first job be triggered? I think the old behavior used to start the first job as soon as…
Contributor (Author)
In either case it waits to read a request for data from the Python side before starting a job, because the map over the partition indices is lazily evaluated.
```diff
 
             // Send response there is a partition to read
             out.writeInt(1)
```
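For intuition, here is a minimal, Spark-free sketch of the pattern the diff and the thread above describe. Everything in it is illustrative: `collect_partition` and `local_iterator` are made-up names, and a `ThreadPoolExecutor` stands in for Spark's scheduler. The generator of `pool.submit(...)` calls mirrors the lazy iterator of `submitJob(...)` futures (no job starts until a future is pulled from it), and pulling the next future before blocking on the current one plays the role of `prefetchIter.headOption`:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def collect_partition(i):
    """Stand-in for a single Spark job that collects partition i."""
    time.sleep(1)
    return [i]


def local_iterator(num_partitions, prefetch):
    with ThreadPoolExecutor(max_workers=2) as pool:
        # Submitted lazily: pool.submit() only runs when the next future
        # is pulled, mirroring the lazy map over rdd.partitions.indices.
        jobs = (pool.submit(collect_partition, i) for i in range(num_partitions))
        pending = next(jobs, None)
        while pending is not None:
            # Peeking ahead (like prefetchIter.headOption) starts the next
            # job *before* we block on the current result.
            upcoming = next(jobs, None) if prefetch else None
            yield from pending.result()  # like ThreadUtils.awaitResult(...)
            pending = upcoming if prefetch else next(jobs, None)


for prefetch in (False, True):
    start = time.time()
    for _ in local_iterator(4, prefetch):
        time.sleep(0.5)  # local processing overlaps the prefetched "job"
    print("prefetch=%s took %.1fs" % (prefetch, time.time() - start))
```

With these sleeps, the `prefetch=True` pass finishes in roughly 3.5 seconds against about 6 for the serial pass, since each half second of local work overlaps the next one-second job.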
New file (86 added lines): a demo comparing `toLocalIterator` with and without prefetch.

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

import sys
import timeit
from random import random
from operator import add

from pyspark.sql import SparkSession


if __name__ == "__main__":
    """
    Usage: prefetch [partitions] [iterations]
    Uses timeit to demonstrate the benefit of prefetch.
    """
    spark = SparkSession\
        .builder\
        .appName("PrefetchDemo")\
        .getOrCreate()

    partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 4
    iterations = int(sys.argv[2]) if len(sys.argv) > 2 else 4

    elems = spark.sparkContext.parallelize(range(1, partitions * 2), partitions)
    elems.cache()
    elems.count()

    def slowCompute(elem):
        """
        Wait ten seconds to simulate some computation then return.
        """
        import time
        time.sleep(10)
        return elem

    def localCompute(elem):
        """
        Simulate processing the data locally.
        """
        import time
        time.sleep(1)
        return elem

    def fetchWithPrefetch():
        prefetchIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=True)
        localCollection = list(map(localCompute, prefetchIter))
        return localCollection

    def fetchRegular():
        regularIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=False)
        localCollection = list(map(localCompute, regularIter))
        return localCollection

    print("Running timers:\n")
    prefetchTimer = timeit.Timer(fetchWithPrefetch)
    prefetchTime = prefetchTimer.timeit(number=iterations)

    regularTimer = timeit.Timer(fetchRegular)
    regularTime = regularTimer.timeit(number=iterations)
    print("\nResults:\n")
    print("Prefetch time:\n")
    print(prefetchTime)
    print("\n")

    print("Regular time:\n")
    print(regularTime)
    print("\n")

    spark.stop()
```
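For scale, with the defaults above (4 partitions, hence the 7 elements of `range(1, 8)`), the regular path pays for everything serially: roughly 4 × 10 s of partition jobs plus 7 × 1 s of local work, about 47 s per `timeit` iteration. The prefetch path keeps the next 10-second job running while the current partition is awaited and processed locally, so each iteration should come in well under that. (The diff doesn't show how the file is invoked; something like `bin/spark-submit prefetch.py 8 2` would be the usual pattern.)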
python/pyspark/tests/test_rdd.py

```diff
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import datetime, timedelta
 import hashlib
 import os
 import random
 import sys
 import tempfile
+import time
 from glob import glob
 
 from py4j.protocol import Py4JJavaError
```
```diff
@@ -68,6 +70,27 @@ def test_to_localiterator(self):
         it2 = rdd2.toLocalIterator()
         self.assertEqual([1, 2, 3], sorted(it2))
 
+    def test_to_localiterator_prefetch(self):
+        # Test that we fetch the next partition in parallel.
+        # We do this by returning the current time and: reading the
+        # first elem, waiting, then reading the second elem.
+        # If not computed in parallel these would be at very different
+        # times, but since the partitions are computed in parallel the
+        # times are "close enough" to the same.
+        rdd = self.sc.parallelize(range(2), 2)
+        times1 = rdd.map(lambda x: datetime.now())
+        times2 = rdd.map(lambda x: datetime.now())
+        timesIterPrefetch = times1.toLocalIterator(prefetchPartitions=True)
+        timesIter = times2.toLocalIterator(prefetchPartitions=False)
+        timesPrefetchHead = next(timesIterPrefetch)
+        timesHead = next(timesIter)
+        time.sleep(2)
+        timesNext = next(timesIter)
+        timesPrefetchNext = next(timesIterPrefetch)
+        print("With prefetch times are: " + str(timesPrefetchHead) + "," + str(timesPrefetchNext))
+        self.assertTrue(timesNext - timesHead >= timedelta(seconds=2))
+        self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=1))
+
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"
```
typo: `prefecthPartitions` -> `prefetchPartitions`