-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-29434][Core] Improve the MapStatuses Serialization Performance #26085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78bbcf2
9714f39
18e5bda
39502ed
3958c01
a4a807e
d0c3532
ed08f2e
a601356
8dc8fad
0bf182a
bd88abd
f184c4c
bc6a14c
5aadd8f
d7fce82
7095b60
08c8fb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,9 @@ | |
|
|
||
| package org.apache.spark | ||
|
|
||
| import java.io._ | ||
| import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} | ||
| import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} | ||
| import java.util.concurrent.locks.ReentrantReadWriteLock | ||
| import java.util.zip.{GZIPInputStream, GZIPOutputStream} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.collection.mutable.{HashMap, ListBuffer, Map} | ||
|
|
@@ -29,6 +28,10 @@ import scala.concurrent.duration.Duration | |
| import scala.reflect.ClassTag | ||
| import scala.util.control.NonFatal | ||
|
|
||
| import com.github.luben.zstd.ZstdInputStream | ||
| import com.github.luben.zstd.ZstdOutputStream | ||
| import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} | ||
|
|
||
| import org.apache.spark.broadcast.{Broadcast, BroadcastManager} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.internal.config._ | ||
|
|
@@ -885,13 +888,18 @@ private[spark] object MapOutputTracker extends Logging { | |
| private val BROADCAST = 1 | ||
|
|
||
| // Serialize an array of map output locations into an efficient byte format so that we can send | ||
| // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will | ||
| // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will | ||
| // generally be pretty compressible because many map outputs will be on the same hostname. | ||
| def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, | ||
| isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { | ||
| val out = new ByteArrayOutputStream | ||
| out.write(DIRECT) | ||
| val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) | ||
| // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard one | ||
| // This implementation doesn't reallocate the whole memory block but allocates | ||
| // additional buffers. This way no buffers need to be garbage collected and | ||
| // the contents don't have to be copied to the new buffer. | ||
| val out = new ApacheByteArrayOutputStream() | ||
| val compressedOut = new ApacheByteArrayOutputStream() | ||
|
|
||
| val objOut = new ObjectOutputStream(out) | ||
| Utils.tryWithSafeFinally { | ||
| // Since statuses can be modified in parallel, sync on it | ||
| statuses.synchronized { | ||
|
|
@@ -900,18 +908,42 @@ private[spark] object MapOutputTracker extends Logging { | |
| } { | ||
| objOut.close() | ||
| } | ||
| val arr = out.toByteArray | ||
|
|
||
| val arr: Array[Byte] = { | ||
| val zos = new ZstdOutputStream(compressedOut) | ||
| Utils.tryWithSafeFinally { | ||
| compressedOut.write(DIRECT) | ||
| // `out.writeTo(zos)` will write the uncompressed data from `out` to `zos` | ||
| // without copying to avoid unnecessary allocation and copy of byte[]. | ||
| out.writeTo(zos) | ||
| } { | ||
| zos.close() | ||
| } | ||
| compressedOut.toByteArray | ||
| } | ||
| if (arr.length >= minBroadcastSize) { | ||
| // Use broadcast instead. | ||
| // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! | ||
| val bcast = broadcastManager.newBroadcast(arr, isLocal) | ||
| // toByteArray creates copy, so we can reuse out | ||
dongjoon-hyun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| out.reset() | ||
| out.write(BROADCAST) | ||
| val oos = new ObjectOutputStream(new GZIPOutputStream(out)) | ||
| oos.writeObject(bcast) | ||
| oos.close() | ||
| val outArr = out.toByteArray | ||
| val oos = new ObjectOutputStream(out) | ||
| Utils.tryWithSafeFinally { | ||
| oos.writeObject(bcast) | ||
| } { | ||
| oos.close() | ||
| } | ||
| val outArr = { | ||
| compressedOut.reset() | ||
| val zos = new ZstdOutputStream(compressedOut) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi, @dbtsai, I am back-porting this into our internal repo. Looks like this compression is unnecessary since
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The actual value of the data (which is already compressed) will not be in the serialized form of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for your clarification. It's indeed not including the compressed data. |
||
| Utils.tryWithSafeFinally { | ||
| compressedOut.write(BROADCAST) | ||
| out.writeTo(zos) | ||
| } { | ||
| zos.close() | ||
| } | ||
| compressedOut.toByteArray | ||
| } | ||
| logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) | ||
| (outArr, bcast) | ||
| } else { | ||
|
|
@@ -924,7 +956,7 @@ private[spark] object MapOutputTracker extends Logging { | |
| assert (bytes.length > 0) | ||
|
|
||
| def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { | ||
| val objIn = new ObjectInputStream(new GZIPInputStream( | ||
| val objIn = new ObjectInputStream(new ZstdInputStream( | ||
| new ByteArrayInputStream(arr, off, len))) | ||
| Utils.tryWithSafeFinally { | ||
| objIn.readObject() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.