From 0937158f40b9113531946c1c37d0ada488ec418d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 15 Aug 2019 13:53:52 -0700 Subject: [PATCH 01/11] Start working on allowing toLocalIter to prefetch in Python --- .../apache/spark/api/python/PythonRDD.scala | 21 +++++++++++++++---- python/pyspark/rdd.py | 7 +++++-- python/pyspark/sql/dataframe.py | 5 ++++- python/pyspark/sql/tests/test_dataframe.py | 6 ++++++ python/pyspark/tests/test_rdd.py | 19 +++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 ++-- 6 files changed, 53 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4d76ff76e675..e2fa3a8f9044 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration @@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, the secret for authentication, and a socket auth * server object that can be used to join the JVM serving thread in Python. */ - def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { + def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = { val handleFunc = (sock: Socket) => { val out = new DataOutputStream(sock.getOutputStream) val in = new DataInputStream(sock.getInputStream) Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Collects a partition on each iteration val collectPartitionIter = rdd.partitions.indices.iterator.map { i => - rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head + var result: Array[Any] = null + rdd.sparkContext.submitJob( + rdd, + (iter: Iterator[Any]) => iter.toArray, + Seq(i), // The partition we are evaluating + (_, res: Array[Any]) => result = res, + result) } + val prefetchIter = collectPartitionIter.buffered // Write data until iteration is complete, client stops iteration, or error occurs var complete = false @@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging { // Read request for data, value of zero will stop iteration or non-zero to continue if (in.readInt() == 0) { complete = true - } else if (collectPartitionIter.hasNext) { + } else if (prefetchIter.hasNext) { + + // Cause the next job to be submitted if prefecthPartitions is enabled. + if (prefetchPartitions) { + prefetchIter.headOption + } // Client requested more data, attempt to collect the next partition - val partitionArray = collectPartitionIter.next() + val partitionArray = ThreadUtils.awaitResult(collectPartitionIter.next(), Duration.Inf) // Send response there is a partition to read out.writeInt(1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 590e8e1e9c07..aab33cf8121c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2444,17 +2444,20 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) - def toLocalIterator(self): + def toLocalIterator(self, prefetchPartitions=False): """ Return an iterator that contains all of the elements in this RDD. 
The iterator will consume as much memory as the largest partition in this RDD. + With prefetch it may consume up to the memory of the 2 largest partitions. + :param prefetchPartitions: If Spark should pre-fetch the next partition + before it is needed. >>> rdd = sc.parallelize(range(10)) >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd(), prefetchPartitions) return _local_iterator_from_socket(sock_info, self._jrdd_deserializer) def barrier(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 398471234d2b..1aefd10ad891 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -520,11 +520,14 @@ def collect(self): @ignore_unicode_prefix @since(2.0) - def toLocalIterator(self): + def toLocalIterator(self, prefetchPartitions=False): """ Returns an iterator that contains all of the rows in this :class:`DataFrame`. The iterator will consume as much memory as the largest partition in this DataFrame. + With prefetch it may consume up to the memory of the 2 largest partitions. + :param prefetchPartitions: If Spark should pre-fetch the next partition + before it is needed. >>> list(df.toLocalIterator()) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 5550a093bf80..df3fb4da65f2 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -689,6 +689,12 @@ def test_to_local_iterator(self): expected = df.collect() self.assertEqual(expected, list(it)) + def test_to_local_iterator_prefetch(self): + df = self.spark.range(8, numPartitions=4) + expected = df.collect() + it = df.toLocalIterator(prefetchPartitions=True) + self.assertEqual(expected, list(it)) + def test_to_local_iterator_not_fully_consumed(self): # SPARK-23961: toLocalIterator throws exception when not fully consumed # Create a DataFrame large enough so that write to socket will eventually block diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index bff080362085..13a521e2fa0e 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -68,6 +68,25 @@ def test_to_localiterator(self): it2 = rdd2.toLocalIterator() self.assertEqual([1, 2, 3], sorted(it2)) + def test_to_localiterator_prefetch(self): + # Test that we fetch the next partition in parallel + # We do this by returning the current time and: + # reading the first elem, waiting, and reading the second elem + # If not in parallel then these would be at different times + # But since they are being computed in parallel we see the time + # is "close enough" to the same. + rdd = self.sc.parallelize(range(10), 10) + times1 = rdd.map(lambda x: time.now) + times2 = rdd.map(lambda x: time.now) + timesIterPrefetch = times1.toLocalIterator(prefetchPartitions=True) + timesIter = times1.toLocalIterator(prefetchPartitions=False) + timesPrefetchHead = timesIterPrefetch.next() + timesHead = timesIter.next() + time.sleep(2) + timesPrefetchNext = timesIterPrefetch.next() + timesNext = timesIter.next() + # TODO write the asserts + def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 87f4c8f5d949..dc5670c78c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3387,9 +3387,9 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Array[Any] = { + private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = { withNewExecutionId { - PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) + PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions) } } From 77cab47eb33e78c2ac9cc64b861d7a8a90a484d3 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 15:14:38 -0700 Subject: [PATCH 02/11] Move the partitionArray blocking up above the peak at head so we are looking at the next elem not the one we are about to block on, and fix the Python tests. --- .../apache/spark/api/python/PythonRDD.scala | 6 ++--- python/pyspark/tests/test_rdd.py | 22 +++++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index e2fa3a8f9044..1b4ab82daa72 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -206,14 +206,14 @@ private[spark] object PythonRDD extends Logging { complete = true } else if (prefetchIter.hasNext) { + // Client requested more data, attempt to collect the next partition + val partitionArray = ThreadUtils.awaitResult(prefetchIter.next(), Duration.Inf) + // Cause the next job to be submitted if prefecthPartitions is enabled. if (prefetchPartitions) { prefetchIter.headOption } - // Client requested more data, attempt to collect the next partition - val partitionArray = ThreadUtils.awaitResult(collectPartitionIter.next(), Duration.Inf) - // Send response there is a partition to read out.writeInt(1) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 13a521e2fa0e..e21594eec85a 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -14,11 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from datetime import datetime, timedelta import hashlib import os import random import sys import tempfile +import time from glob import glob from py4j.protocol import Py4JJavaError @@ -75,17 +77,19 @@ def test_to_localiterator_prefetch(self): # If not in parallel then these would be at different times # But since they are being computed in parallel we see the time # is "close enough" to the same. 
- rdd = self.sc.parallelize(range(10), 10) - times1 = rdd.map(lambda x: time.now) - times2 = rdd.map(lambda x: time.now) + rdd = self.sc.parallelize(range(2), 2) + times1 = rdd.map(lambda x: datetime.now()) + times2 = rdd.map(lambda x: datetime.now()) timesIterPrefetch = times1.toLocalIterator(prefetchPartitions=True) - timesIter = times1.toLocalIterator(prefetchPartitions=False) - timesPrefetchHead = timesIterPrefetch.next() - timesHead = timesIter.next() + timesIter = times2.toLocalIterator(prefetchPartitions=False) + timesPrefetchHead = next(timesIterPrefetch) + timesHead = next(timesIter) time.sleep(2) - timesPrefetchNext = timesIterPrefetch.next() - timesNext = timesIter.next() - # TODO write the asserts + timesNext = next(timesIter) + timesPrefetchNext = next(timesIterPrefetch) + print("With prefetch times are: " + str(timesPrefetchHead) + "," + str(timesPrefetchNext)) + self.assertTrue(timesNext - timesHead >= timedelta(seconds=2)) + self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=2)) def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 From 4fc6db92c25866b068fd3d2610e461a029b38029 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 16:59:23 -0700 Subject: [PATCH 03/11] Fix python long line --- python/pyspark/rdd.py | 4 +++- python/pyspark/tests/test_rdd.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index aab33cf8121c..11d482815a9c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2457,7 +2457,9 @@ def toLocalIterator(self, prefetchPartitions=False): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd(), prefetchPartitions) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe( + self._jrdd.rdd(), + prefetchPartitions) return _local_iterator_from_socket(sock_info, self._jrdd_deserializer) def barrier(self): diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index e21594eec85a..652fbd290bef 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -89,7 +89,7 @@ def test_to_localiterator_prefetch(self): timesPrefetchNext = next(timesIterPrefetch) print("With prefetch times are: " + str(timesPrefetchHead) + "," + str(timesPrefetchNext)) self.assertTrue(timesNext - timesHead >= timedelta(seconds=2)) - self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=2)) + self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=1)) def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 From b39a83c5b66fc6ee3950f490927729090311302e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 17:00:42 -0700 Subject: [PATCH 04/11] Pull the head off & peak at the head+1 elem while before we block on the head --- bin/spark-shell | 1 - .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bin/spark-shell b/bin/spark-shell index e92013797498..97444bc80957 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -95,4 +95,3 @@ main "$@" # then reenable echo and propagate the code. exit_status=$? 
onExit - diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 1b4ab82daa72..3b9a6285b0ad 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -207,12 +207,12 @@ private[spark] object PythonRDD extends Logging { } else if (prefetchIter.hasNext) { // Client requested more data, attempt to collect the next partition - val partitionArray = ThreadUtils.awaitResult(prefetchIter.next(), Duration.Inf) - + val partitionFuture = prefetchIter.next() // Cause the next job to be submitted if prefecthPartitions is enabled. if (prefetchPartitions) { prefetchIter.headOption } + val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf) // Send response there is a partition to read out.writeInt(1) From e0b3871471fc475c4fa07351bb640252c8e2e6ff Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 17:01:49 -0700 Subject: [PATCH 05/11] Add a micro benchmark in prefetch --- examples/src/main/python/prefetch.py | 86 ++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 examples/src/main/python/prefetch.py diff --git a/examples/src/main/python/prefetch.py b/examples/src/main/python/prefetch.py new file mode 100644 index 000000000000..23f04f854f14 --- /dev/null +++ b/examples/src/main/python/prefetch.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys +import timeit +from random import random +from operator import add + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: prefetch [partitions] [iterations] + Uses timeit to demonstrate the benefit of prefetch. + """ + spark = SparkSession\ + .builder\ + .appName("PrefetchDemo")\ + .getOrCreate() + + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 4 + iterations = int(sys.argv[2]) if len(sys.argv) > 2 else 4 + + elems = spark.sparkContext.parallelize(range(1, partitions * 2), partitions) + elems.cache() + elems.count() + + def slowCompute(elem): + """ + Wait ten seconds to simulate some computation then return. + """ + import time + time.sleep(10) + return elem + + def localCompute(elem): + """ + Simulate processing the data locally. 
+ """ + import time + time.sleep(1) + return elem + + def fetchWithPrefetch(): + prefetchIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=True) + localCollection = list(map(localCompute, prefetchIter)) + return localCollection + + def fetchRegular(): + regularIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=False) + localCollection = list(map(localCompute, regularIter)) + return localCollection + + print("Running timers:\n") + prefetchTimer = timeit.Timer(fetchWithPrefetch) + prefetchTime = prefetchTimer.timeit(number=iterations) + + regularTimer = timeit.Timer(fetchRegular) + regularTime = regularTimer.timeit(number=iterations) + print("\nResults:\n") + print("Prefetch time:\n") + print(prefetchTime) + print("\n") + + print("Regular time:\n") + print(regularTime) + print("\n") + + spark.stop() From a060214580e49818de2b7613a119e2ddce3ef782 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 17:18:38 -0700 Subject: [PATCH 06/11] accidental line change we don't need --- bin/spark-shell | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/spark-shell b/bin/spark-shell index 97444bc80957..29c5576253a2 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -95,3 +95,4 @@ main "$@" # then reenable echo and propagate the code. exit_status=$? onExit +\n From c477fec27618acea2864dbb24fda58b5736af86b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 16 Aug 2019 17:20:53 -0700 Subject: [PATCH 07/11] oops on \n --- bin/spark-shell | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-shell b/bin/spark-shell index 29c5576253a2..e92013797498 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -95,4 +95,4 @@ main "$@" # then reenable echo and propagate the code. exit_status=$? onExit -\n + From 6dc47488dc72902cbbf783e4ba6ee2a5166382e5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 20 Aug 2019 17:36:04 -0700 Subject: [PATCH 08/11] Fix missing call --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1aefd10ad891..934f646ed5d6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -532,7 +532,7 @@ def toLocalIterator(self, prefetchPartitions=False): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - sock_info = self._jdf.toPythonIterator() + sock_info = self._jdf.toPythonIterator(prefetchPartitions) return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix From e0327a24e48c4ba7a483ef2590c9dd72c6bedfc5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 20 Aug 2019 17:39:12 -0700 Subject: [PATCH 09/11] Fix sphinx build issues --- python/pyspark/rdd.py | 1 + python/pyspark/sql/dataframe.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 11d482815a9c..17f159053900 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2452,6 +2452,7 @@ def toLocalIterator(self, prefetchPartitions=False): :param prefetchPartitions: If Spark should pre-fetch the next partition before it is needed. 
+ >>> rdd = sc.parallelize(range(10)) >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 934f646ed5d6..03b37fa7d0d9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -528,6 +528,7 @@ def toLocalIterator(self, prefetchPartitions=False): :param prefetchPartitions: If Spark should pre-fetch the next partition before it is needed. + >>> list(df.toLocalIterator()) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ From 11d6688b659c38712085681b6aef85992807ba3b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 10 Sep 2019 11:52:30 -0700 Subject: [PATCH 10/11] Cleanup the tests and some typos --- .../org/apache/spark/api/python/PythonRDD.scala | 2 +- python/pyspark/tests/test_rdd.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 3b9a6285b0ad..7cbfb71beea3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -208,7 +208,7 @@ private[spark] object PythonRDD extends Logging { // Client requested more data, attempt to collect the next partition val partitionFuture = prefetchIter.next() - // Cause the next job to be submitted if prefecthPartitions is enabled. + // Cause the next job to be submitted if prefetchPartitions is enabled. if (prefetchPartitions) { prefetchIter.headOption } diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 652fbd290bef..e7a7971dfc9a 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -80,16 +80,15 @@ def test_to_localiterator_prefetch(self): rdd = self.sc.parallelize(range(2), 2) times1 = rdd.map(lambda x: datetime.now()) times2 = rdd.map(lambda x: datetime.now()) - timesIterPrefetch = times1.toLocalIterator(prefetchPartitions=True) - timesIter = times2.toLocalIterator(prefetchPartitions=False) - timesPrefetchHead = next(timesIterPrefetch) - timesHead = next(timesIter) + times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True) + times_iter = times2.toLocalIterator(prefetchPartitions=False) + times_prefetch_head = next(times_iter_prefetch) + times_head = next(times_iter) time.sleep(2) - timesNext = next(timesIter) - timesPrefetchNext = next(timesIterPrefetch) - print("With prefetch times are: " + str(timesPrefetchHead) + "," + str(timesPrefetchNext)) - self.assertTrue(timesNext - timesHead >= timedelta(seconds=2)) - self.assertTrue(timesPrefetchNext - timesPrefetchHead < timedelta(seconds=1)) + times_next = next(times_iter) + times_prefetch_next = next(times_iter_prefetch) + self.assertTrue(times_next - times_head >= timedelta(seconds=2)) + self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1)) def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 From f8e67f30f03437cdac248695ee75d42e6f96df66 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 10 Sep 2019 11:54:46 -0700 Subject: [PATCH 11/11] Remove the prefetch example we used as a benchmark --- examples/src/main/python/prefetch.py | 86 ---------------------------- 1 file changed, 86 deletions(-) delete mode 100644 examples/src/main/python/prefetch.py diff --git a/examples/src/main/python/prefetch.py b/examples/src/main/python/prefetch.py deleted file mode 100644 index 
23f04f854f14..000000000000 --- a/examples/src/main/python/prefetch.py +++ /dev/null @@ -1,86 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys -import timeit -from random import random -from operator import add - -from pyspark.sql import SparkSession - - -if __name__ == "__main__": - """ - Usage: prefetch [partitions] [iterations] - Uses timeit to demonstrate the benefit of prefetch. - """ - spark = SparkSession\ - .builder\ - .appName("PrefetchDemo")\ - .getOrCreate() - - partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 4 - iterations = int(sys.argv[2]) if len(sys.argv) > 2 else 4 - - elems = spark.sparkContext.parallelize(range(1, partitions * 2), partitions) - elems.cache() - elems.count() - - def slowCompute(elem): - """ - Wait ten seconds to simulate some computation then return. - """ - import time - time.sleep(10) - return elem - - def localCompute(elem): - """ - Simulate processing the data locally. - """ - import time - time.sleep(1) - return elem - - def fetchWithPrefetch(): - prefetchIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=True) - localCollection = list(map(localCompute, prefetchIter)) - return localCollection - - def fetchRegular(): - regularIter = elems.mapPartitions(slowCompute).toLocalIterator(prefetchPartitions=False) - localCollection = list(map(localCompute, regularIter)) - return localCollection - - print("Running timers:\n") - prefetchTimer = timeit.Timer(fetchWithPrefetch) - prefetchTime = prefetchTimer.timeit(number=iterations) - - regularTimer = timeit.Timer(fetchRegular) - regularTime = regularTimer.timeit(number=iterations) - print("\nResults:\n") - print("Prefetch time:\n") - print(prefetchTime) - print("\n") - - print("Regular time:\n") - print(regularTime) - print("\n") - - spark.stop()
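
Usage note (not part of the patch series): a minimal sketch of how the new prefetchPartitions flag added by these patches is exercised from PySpark, modeled on the benchmark that patch 05 adds and patch 11 removes. The app name and the sleep-based stand-ins for "slow" cluster work and "slow" local work are illustrative assumptions, not code from the patches; the API calls (toLocalIterator(prefetchPartitions=...)) are the ones the series introduces.

    import time
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("PrefetchSketch").getOrCreate()
    sc = spark.sparkContext

    def slow(x):
        # Stand-in for an expensive per-partition computation on the executors.
        time.sleep(5)
        return x

    rdd = sc.parallelize(range(4), 4).map(slow)

    # Default behaviour: each partition's job is only submitted when next()
    # asks for it, so local work and cluster work never overlap.
    for elem in rdd.toLocalIterator():
        time.sleep(1)  # local processing of the element

    # With prefetch: while the driver processes one partition locally, the job
    # for the next partition has already been submitted, at the cost of holding
    # up to the two largest partitions in driver memory at once.
    for elem in rdd.toLocalIterator(prefetchPartitions=True):
        time.sleep(1)  # local processing overlaps the next partition's job

    spark.stop()

The same flag is available on DataFrames (df.toLocalIterator(prefetchPartitions=True)), which routes through Dataset.toPythonIterator as shown in patches 01 and 08.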