Fix serialization
sryza committed Apr 28, 2015
commit 8c70dd909c20af6dacaf4091e964a1239a221016
core/src/main/scala/org/apache/spark/serializer/Serializer.scala (31 additions, 0 deletions)
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
  */
 @DeveloperApi
 abstract class SerializationStream {
+  /** The most general-purpose method to write an object. */
   def writeObject[T: ClassTag](t: T): SerializationStream
+  /** Writes the object representing the key of a key-value pair. */
+  def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+  /** Writes the object representing the value of a key-value pair. */
+  def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
   def flush(): Unit
   def close(): Unit

@@ -120,7 +125,12 @@ abstract class SerializationStream {
  */
 @DeveloperApi
 abstract class DeserializationStream {
+  /** The most general-purpose method to read an object. */
   def readObject[T: ClassTag](): T
+  /** Reads the object representing the key of a key-value pair. */
+  def readKey[T: ClassTag](): T = readObject[T]()
+  /** Reads the object representing the value of a key-value pair. */
+  def readValue[T: ClassTag](): T = readObject[T]()
   def close(): Unit

   /**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
       DeserializationStream.this.close()
     }
   }
+
+  /**
+   * Read the elements of this stream through an iterator over key-value pairs. This can only be
+   * called once, as reading each element will consume data from the input source.
+   */
+  def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+    override protected def getNext() = {
+      try {
+        (readKey[Any](), readValue[Any]())
+      } catch {
+        case eof: EOFException => {
+          finished = true
+          null
+        }
+      }
+    }
+
+    override protected def close() {
+      DeserializationStream.this.close()
+    }
+  }
 }
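
For orientation, here is a minimal sketch (not part of this diff) of how the new key/value hooks compose with the existing serializer API. It assumes the built-in JavaSerializer and a default SparkConf; serializers that care about keys and values can override writeKey/writeValue and readKey/readValue, as the SQL serializer further down does.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object KeyValueRoundTrip {
  def main(args: Array[String]): Unit = {
    val instance = new JavaSerializer(new SparkConf()).newInstance()

    // Write a few key-value pairs; writeKey/writeValue default to writeObject.
    val bytes = new ByteArrayOutputStream()
    val out = instance.serializeStream(bytes)
    out.writeKey("a").writeValue(1)
    out.writeKey("b").writeValue(2)
    out.close()

    // Read them back as pairs; asKeyValueIterator keeps going until EOF.
    val in = instance.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
    in.asKeyValueIterator.foreach { case (k, v) => println(s"$k -> $v") }
  }
}
```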
@@ -213,8 +213,8 @@ private[spark] class DiskBlockObjectWriter(
       open()
     }

-    objOut.writeObject(key)
-    objOut.writeObject(value)
+    objOut.writeKey(key)
+    objOut.writeValue(value)
     numRecordsWritten += 1
     writeMetrics.incShuffleRecordsWritten(1)

Review comment (Contributor):
To avoid code duplication, I think we could have called recordWritten() here instead, which would handle updating the bytes written etc.
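A standalone sketch of the pattern the reviewer is pointing at: funnel every record through one recordWritten()-style helper so the record counter and the write metrics are updated in a single place. The class and member names below are illustrative, not the actual DiskBlockObjectWriter internals.

```scala
import org.apache.spark.serializer.SerializationStream

// Illustrative only: centralize per-record bookkeeping instead of repeating it at each call site.
class PairWriter(out: SerializationStream) {
  private var numRecordsWritten = 0L

  private def recordWritten(): Unit = {
    numRecordsWritten += 1
    // ...update shuffle write metrics here, in one place...
  }

  def write(key: Any, value: Any): Unit = {
    out.writeKey(key)
    out.writeValue(value)
    recordWritten()
  }
}
```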
@@ -27,7 +27,6 @@ import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.serializer.{SerializerInstance, Serializer}
 import org.apache.spark.util.{CompletionIterator, Utils}
-import org.apache.spark.util.collection.PairIterator

 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -300,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { is0 =>
           val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = new PairIterator(serializerInstance.deserializeStream(is).asIterator)
+          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
           CompletionIterator[Any, Iterator[Any]](iter, {
             // Once the iterator is exhausted, release the buffer and set currentResult to null
             // so we don't release it again in cleanup.
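The removed PairIterator class is not shown in this diff; presumably it paired up consecutive elements of the flat asIterator, roughly like the hypothetical reconstruction below. asKeyValueIterator subsumes that wrapper by reading the key and the value directly from the stream.

```scala
// Hypothetical reconstruction of the removed wrapper, for comparison only.
class PairIteratorSketch[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
  override def hasNext: Boolean = iter.hasNext
  override def next(): (K, V) =
    (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
}
```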
@@ -433,8 +433,8 @@ class ExternalAppendOnlyMap[K, V, C](
      */
     private def readNextItem(): (K, C) = {
       try {
-        val k = deserializeStream.readObject().asInstanceOf[K]
-        val c = deserializeStream.readObject().asInstanceOf[C]
+        val k = deserializeStream.readKey().asInstanceOf[K]
+        val c = deserializeStream.readValue().asInstanceOf[C]
         val item = (k, c)
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
@@ -599,8 +599,8 @@ private[spark] class ExternalSorter[K, V, C](
       if (finished || deserializeStream == null) {
         return null
       }
-      val k = deserializeStream.readObject().asInstanceOf[K]
-      val c = deserializeStream.readObject().asInstanceOf[C]
+      val k = deserializeStream.readKey().asInstanceOf[K]
+      val c = deserializeStream.readValue().asInstanceOf[C]
       lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1
@@ -105,8 +105,8 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     var metaBufferPos = 0
     def hasNext: Boolean = metaBufferPos < metaBuffer.position
     def next(): ((Int, K), V) = {
-      val key = deserStream.readObject[Any]().asInstanceOf[K]
-      val value = deserStream.readObject[Any]().asInstanceOf[V]
+      val key = deserStream.readKey[Any]().asInstanceOf[K]
+      val value = deserStream.readValue[Any]().asInstanceOf[V]
       val partition = metaBuffer.get(metaBufferPos + PARTITION)
       metaBufferPos += RECORD_SIZE
       ((partition, key), value)
@@ -50,17 +50,27 @@ private[sql] class Serializer2SerializationStream(
   extends SerializationStream with Logging {

   val rowOut = new DataOutputStream(out)
-  val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

-  def writeObject[T: ClassTag](t: T): SerializationStream = {
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
     writeKey(kv._1)
     writeValue(kv._2)

     this
   }

+  override def writeKey[T: ClassTag](t: T): SerializationStream = {
+    writeKeyFunc(t.asInstanceOf[Row])
+    this
+  }
+
+  override def writeValue[T: ClassTag](t: T): SerializationStream = {
+    writeValueFunc(t.asInstanceOf[Row])
+    this
+  }
+
   def flush(): Unit = {
     rowOut.flush()
   }
@@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream(

   val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
   val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)

-  def readObject[T: ClassTag](): T = {
-    readKey()
-    readValue()
+  override def readObject[T: ClassTag](): T = {
+    readKeyFunc()
+    readValueFunc()

     (key, value).asInstanceOf[T]
   }

-  def close(): Unit = {
+  override def readKey[T: ClassTag](): T = {
+    readKeyFunc()
+    key.asInstanceOf[T]
+  }
+
+  override def readValue[T: ClassTag](): T = {
+    readValueFunc()
+    value.asInstanceOf[T]
+  }
+
+  override def close(): Unit = {
     rowIn.close()
   }
 }
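One consequence of the implementation above, not called out in the diff: readKey and readValue hand back the same SpecificMutableRow instances on every call, so a caller that retains rows across reads has to copy them. A hedged usage sketch:

```scala
import org.apache.spark.serializer.DeserializationStream
import org.apache.spark.sql.Row

object RetainRows {
  // Drain a stream produced by the key/value-aware serializer, copying each row
  // because the underlying mutable row objects are reused between records.
  def drain(stream: DeserializationStream): Seq[(Row, Row)] = {
    stream.asKeyValueIterator.map { case (k: Row, v: Row) =>
      (k.copy(), v.copy())
    }.toSeq
  }
}
```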