@@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}

def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
sc.broadcast(obj)
}

def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
// The right way to implement this would be to use TypeTags to get the full
// type of T. Since I don't want to introduce breaking changes throughout the
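The new readBroadcastFromFile above expects a length-prefixed file: a 4-byte big-endian integer (what DataInputStream.readInt consumes) followed by exactly that many payload bytes. As a purely illustrative sketch (write_broadcast_file is a hypothetical helper, not part of this PR), a matching writer on the Python side would look roughly like this:

    import struct

    def write_broadcast_file(path, payload):
        # Hypothetical helper: write the layout readBroadcastFromFile reads,
        # a big-endian 4-byte length (struct ">i") followed by the raw bytes.
        with open(path, "wb") as f:
            f.write(struct.pack(">i", len(payload)))
            f.write(payload)

In the PR itself this framing comes from dump_stream on the driver (see the context.py change below), which writes each serialized object with the same length prefix.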
python/pyspark/broadcast.py (28 additions, 9 deletions)
@@ -21,18 +21,16 @@
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]

>>> from pyspark.broadcast import _broadcastRegistry
>>> _broadcastRegistry[b.bid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]

>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
Contributor comment: Good call here; it was a bad idea to expose these internals in user-facing module doctests.

>>> b.unpersist()

>>> large_broadcast = sc.broadcast(list(range(10000)))
"""
import os

from pyspark.serializers import CompressedSerializer, PickleSerializer

# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}

@@ -52,17 +50,38 @@ class Broadcast(object):
Access its value through C{.value}.
"""

def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
def __init__(self, bid, value, java_broadcast=None,
pickle_registry=None, path=None):
"""
Should not be called directly by users -- use
L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
instead.
"""
self.value = value
self.bid = bid
if path is None:
self.value = value
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
self.path = path

def unpersist(self, blocking=False):
Contributor comment: Can you add a docstring? It's fine to just copy it over from the Scala equivalent. In this case:

  /**
   * Delete cached copies of this broadcast on the executors. If the broadcast is used after
   * this is called, it will need to be re-sent to each executor.
   * @param blocking Whether to block until unpersisting has completed
   */

self._jbroadcast.unpersist(blocking)
os.unlink(self.path)
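Following up on the comment above, a sketch of how the requested docstring could read in this Python method, assuming the Scala wording is carried over essentially verbatim (an editor's illustration, not part of the diff):

    def unpersist(self, blocking=False):
        """
        Delete cached copies of this broadcast on the executors. If the
        broadcast is used after this is called, it will need to be re-sent
        to each executor.

        :param blocking: Whether to block until unpersisting has completed
        """
        self._jbroadcast.unpersist(blocking)
        os.unlink(self.path)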

def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))

def __getattr__(self, item):
if item == 'value' and self.path is not None:
ser = CompressedSerializer(PickleSerializer())
value = ser.load_stream(open(self.path)).next()
self.value = value
return value

raise AttributeError(item)
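The __getattr__ branch above implements a lazy-load-and-cache pattern: the broadcast value is only deserialized from the driver-side file on first access to .value, and the result is then stored on the instance so later lookups never reach __getattr__ again. A standalone illustration of the same trick (not PySpark code):

    class LazyValue(object):
        def __init__(self, path):
            self.path = path

        def __getattr__(self, item):
            # Only called when normal attribute lookup fails.
            if item == 'value':
                with open(self.path) as f:
                    value = f.read()
                self.value = value  # cache on the instance for future lookups
                return value
            raise AttributeError(item)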


if __name__ == "__main__":
import doctest
doctest.testmod()
python/pyspark/context.py (13 additions, 7 deletions)
@@ -29,7 +29,7 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer
PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -566,13 +566,19 @@ def broadcast(self, value):
"""
Broadcast a read-only variable to the cluster, returning a
L{Broadcast<pyspark.broadcast.Broadcast>}
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
object for reading it in distributed functions. The variable will
be sent to each cluster only once.

:keep: Keep the `value` in driver or not.
"""
pickleSer = PickleSerializer()
pickled = pickleSer.dumps(value)
jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars)
ser = CompressedSerializer(PickleSerializer())
# passing a large object through py4j is very slow and uses a lot of memory
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
ser.dump_stream([value], tempFile)
tempFile.close()
jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
return Broadcast(jbroadcast.id(), None, jbroadcast,
self._pickled_broadcast_vars, tempFile.name)

def accumulator(self, value, accum_param=None):
"""
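A usage sketch of the new broadcast() path above, assuming a live SparkContext named sc: the value is serialized, compressed, and spilled to a temporary file on the driver instead of being pushed through py4j, and workers read it back lazily through Broadcast.value.

    nums = list(range(1000))
    b = sc.broadcast(nums)  # written to a temp file, registered via readBroadcastFromFile
    sizes = sc.parallelize([1, 2, 3]).map(lambda x: len(b.value)).collect()
    assert sizes == [1000, 1000, 1000]
    b.unpersist()  # drops executor copies and unlinks the driver-side temp file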
python/pyspark/rdd.py (3 additions, 2 deletions)
@@ -35,7 +35,7 @@

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long
PickleSerializer, pack_long, CompressedSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -1809,7 +1809,8 @@ def _jrdd(self):
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
pickled_command = CloudPickleSerializer().dumps(command)
ser = CompressedSerializer(CloudPickleSerializer())
Contributor comment: This is a good idea. It wouldn't surprise me if the pickle data was highly compressible due to frequently-occurring groups of pickle opcodes, etc.

pickled_command = ser.dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
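On the compressibility point raised in the comment above, a quick way to check the intuition is to compress a pickled payload with the same zlib level the PR uses (a rough, illustrative measurement, not part of the change):

    import zlib
    from pickle import dumps

    payload = dumps([("key", i, float(i)) for i in range(10000)])
    ratio = len(zlib.compress(payload, 1)) / float(len(payload))
    print("compressed to %.0f%% of the original size" % (ratio * 100))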
python/pyspark/serializers.py (17 additions, 0 deletions)
@@ -67,6 +67,7 @@
import sys
import types
import collections
import zlib

from pyspark import cloudpickle

@@ -403,6 +404,22 @@ def loads(self, obj):
raise ValueError("invalid serialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
"""
Compress the serialized data with zlib.
"""

def __init__(self, serializer):
FramedSerializer.__init__(self)
self.serializer = serializer

def dumps(self, obj):
return zlib.compress(self.serializer.dumps(obj), 1)

def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))


class UTF8Deserializer(Serializer):

"""
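CompressedSerializer above is a thin wrapper: each frame produced by the underlying serializer is run through zlib at level 1 (favoring speed over ratio) on the way out and decompressed on the way in. A minimal round-trip sketch of its per-object API:

    from pyspark.serializers import CompressedSerializer, PickleSerializer

    ser = CompressedSerializer(PickleSerializer())
    data = {"x": list(range(100))}
    assert ser.loads(ser.dumps(data)) == data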
python/pyspark/tests.py (7 additions, 0 deletions)
@@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self):
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())

def test_large_broadcast(self):
N = 100000
data = [[float(i) for i in range(300)] for i in range(N)]
bdata = self.sc.broadcast(data) # 270MB
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)


class TestIO(PySparkTestCase):

python/pyspark/worker.py (5 additions, 3 deletions)
@@ -30,7 +30,8 @@
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer


pickleSer = PickleSerializer()
@@ -65,12 +66,13 @@ def main(infile, outfile):

# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = pickleSer._read_with_length(infile)
value = ser._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)

command = pickleSer._read_with_length(infile)
command = ser._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
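The worker-side reads above rely on the same framing throughout: _read_with_length (inherited from FramedSerializer) reads a 4-byte big-endian length and then that many bytes, which CompressedSerializer zlib-decompresses and unpickles. A sketch of what that amounts to on the wire (an editor's illustration of the assumed format, not code from the PR):

    import struct
    import zlib
    from pickle import loads

    def read_framed_compressed(infile):
        # Assumed wire format: 4-byte big-endian length, then that many bytes
        # of zlib-compressed pickle data.
        length = struct.unpack(">i", infile.read(4))[0]
        return loads(zlib.decompress(infile.read(length)))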