27 commits
779c0f9
initial commit of sort-merge shuffle reader
jerryshao Sep 5, 2014
4f46dc0
Readability improvements to SortShuffleReader
sryza Oct 22, 2014
0861cf9
Clarify mergeWidth logic
sryza Oct 23, 2014
8f49b78
Add blocks remaining at level counter back in
sryza Oct 23, 2014
fcafa16
Small fix
sryza Oct 24, 2014
21dae69
Move merge to a separate class and use a priority queue instead of le…
sryza Oct 25, 2014
8e3766a
Rebase to the latest code and fix some conflicts
jerryshao Oct 30, 2014
98c039b
SortShuffleReader code improvement
jerryshao Nov 4, 2014
7d999ef
Changes to rebase to the latest master branch
jerryshao Nov 5, 2014
319e6d1
Don't spill more blocks than we need to
sryza Nov 5, 2014
96ef5c1
Fix bug: add to inMemoryBlocks
sryza Nov 5, 2014
d481c98
Fix another bug
sryza Nov 5, 2014
bf6a49d
Bug fix and revert ShuffleMemoryManager
jerryshao Nov 5, 2014
79dc823
Fix some bugs in spilling to disk
jerryshao Nov 7, 2014
2e04b85
Modify to use BlockObjectWriter to write data
jerryshao Nov 10, 2014
c1f97b6
Fix incorrect block size introduced bugs
jerryshao Nov 11, 2014
b5e472d
Address the comments
jerryshao Nov 12, 2014
40c59df
Fix some bugs
jerryshao Nov 12, 2014
42bf77d
Improve the failure process and expand ManagedBuffer
jerryshao Nov 14, 2014
a9eaef8
Copy the memory from off-heap to on-heap and some code style modifica…
jerryshao Nov 17, 2014
6f48c5c
Fix rebase introduced issue
jerryshao Nov 18, 2014
c2ddcce
Revert some unwanted changes
jerryshao Nov 18, 2014
f170db3
Clean up comments, break up large methods, spill based on actual bloc…
sryza Nov 24, 2014
123aea1
Log improve
jerryshao Nov 25, 2014
e035105
Fix scala style issue
jerryshao Nov 25, 2014
8b73701
Fix rebase issues
jerryshao Feb 22, 2015
d6c94da
Fix dead lock
jerryshao Apr 13, 2015
initial commit of sort-merge shuffle reader
Conflicts:
	core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
	core/src/main/scala/org/apache/spark/storage/BlockManager.scala
	core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala

Conflicts:
	core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
jerryshao committed Apr 13, 2015
commit 779c0f95f66f3487adeccaf67f8f3136d2b56c68
core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -93,14 +93,16 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   }
 
   /** Release numBytes bytes for the current thread. */
-  def release(numBytes: Long): Unit = synchronized {
-    val threadId = Thread.currentThread().getId
-    val curMem = threadMemory.getOrElse(threadId, 0L)
+  def release(numBytes: Long): Unit = release(numBytes, Thread.currentThread().getId)
+
+  /** Release numBytes bytes for the specific thread. */
+  def release(numBytes: Long, tid: Long): Unit = synchronized {
+    val curMem = threadMemory.getOrElse(tid, 0L)
     if (curMem < numBytes) {
       throw new SparkException(
         s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
     }
-    threadMemory(threadId) -= numBytes
+    threadMemory(tid) -= numBytes
     notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }

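The new two-argument overload lets memory that was booked against the task thread be returned from a different thread. A minimal usage sketch (assumed, following the SortShuffleReader changes below; blockSize is a placeholder value):

import org.apache.spark.SparkEnv

val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
val tid = Thread.currentThread().getId            // captured on the task (reader) thread
val blockSize: Long = 4 * 1024 * 1024             // placeholder block size
val granted = shuffleMemoryManager.tryToAcquire(blockSize)   // accounted against tid
// ... later, possibly on the merging thread, after the block has been consumed:
shuffleMemoryManager.release(blockSize, tid)      // frees memory booked against the task thread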
core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -27,7 +27,7 @@ import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
-private[hash] object BlockStoreShuffleFetcher extends Logging {
+private[shuffle] object BlockStoreShuffleFetcher extends Logging {
   def fetch[T](
       shuffleId: Int,
       reduceId: Int,
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency}
 import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.hash.HashShuffleReader
 
 private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
 
@@ -48,7 +47,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C] = {
     // We currently use the same block store shuffle fetcher as the hash-based shuffle.
-    new HashShuffleReader(
+    new SortShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
   }
 
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleReader.scala (new file)
@@ -0,0 +1,298 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.sort

import java.io.{BufferedOutputStream, FileOutputStream, File}
import java.nio.ByteBuffer
import java.util.Comparator
import java.util.concurrent.{CountDownLatch, TimeUnit, LinkedBlockingQueue}

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.{Logging, InterruptibleIterator, SparkEnv, TaskContext}
import org.apache.spark.network.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
import org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher
import org.apache.spark.storage._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
extends ShuffleReader[K, C] with Logging {

require(endPartition == startPartition + 1,
"Sort shuffle currently only supports fetching one partition")

sealed trait ShufflePartition
case class MemoryPartition(blockId: BlockId, blockData: ManagedBuffer) extends ShufflePartition
case class FilePartition(blockId: BlockId, mappedFile: File) extends ShufflePartition

private val mergingGroup = new LinkedBlockingQueue[ShufflePartition]()
private val mergedGroup = new LinkedBlockingQueue[ShufflePartition]()
private var numSplits: Int = 0
private val mergeFinished = new CountDownLatch(1)
private val mergingThread = new MergingThread()
private val tid = Thread.currentThread().getId
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = null

private val dep = handle.dependency
private val conf = SparkEnv.get.conf
private val blockManager = SparkEnv.get.blockManager
private val ser = Serializer.getSerializer(dep.serializer)
private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

private val ioSortFactor = conf.getInt("spark.shuffle.ioSortFactor", 100)
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
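// ioSortFactor caps how many partition streams are merged in a single pass; fileBufferSize
// (in bytes) is the buffer used when writing merged runs back to disk.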

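// If the dependency defines no key ordering, fall back to ordering keys by hashCode so that
// equal keys end up next to each other during the merge.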
private val keyComparator: Comparator[K] = dep.keyOrdering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K) = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})

override def read(): Iterator[Product2[K, C]] = {
if (!dep.mapSideCombine && dep.aggregator.isDefined) {
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
new InterruptibleIterator(context,
dep.aggregator.get.combineValuesByKey(iter, context))
} else {
sortShuffleRead()
}
}

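// Fetch the raw map-output blocks, keeping each one in memory when the ShuffleMemoryManager
// grants enough space and spilling it to a temporary file otherwise. A background merging
// thread repeatedly merges groups of partitions; once it finishes, the final group is
// merge-sorted (and combined, if an aggregator is defined) into the iterator returned to
// the reducer.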
private def sortShuffleRead(): Iterator[Product2[K, C]] = {
val rawBlockIterator = fetchRawBlock()

mergingThread.setNumSplits(numSplits)
mergingThread.setDaemon(true)
mergingThread.start()

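// Try to book each fetched block's size with the ShuffleMemoryManager. If the full amount
// cannot be granted, release whatever was granted and spill the block to a temporary file.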
for ((blockId, blockData) <- rawBlockIterator) {
if (blockData.isEmpty) {
throw new IllegalStateException(s"block $blockId is empty for unknown reason")
}

val amountToRequest = blockData.get.size
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
val shouldSpill = if (granted < amountToRequest) {
shuffleMemoryManager.release(granted)
logInfo(s"Grant memory $granted less than the amount to request $amountToRequest, " +
s"spilling data to file")
true
} else {
false
}

if (!shouldSpill) {
mergingGroup.offer(MemoryPartition(blockId, blockData.get))
} else {
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
val channel = new FileOutputStream(file).getChannel()
val byteBuffer = blockData.get.nioByteBuffer()
while (byteBuffer.remaining() > 0) {
channel.write(byteBuffer)
}
channel.close()
mergingGroup.offer(FilePartition(tmpBlockId, file))
}

shuffleRawBlockFetcherItr.currentResult = null
}

mergeFinished.await()

// Merge the final group of partitions so the result can be fed directly to the reducer
val finalMergedPartArray = mergedGroup.toArray(new Array[ShufflePartition](mergedGroup.size()))
val finalItrGroup = getIteratorGroup(finalMergedPartArray)
val mergedItr = if (dep.aggregator.isDefined) {
ExternalSorter.mergeWithAggregation(finalItrGroup, dep.aggregator.get.mergeCombiners,
keyComparator, dep.keyOrdering.isDefined)
} else {
ExternalSorter.mergeSort(finalItrGroup, keyComparator)
}

mergedGroup.clear()

// Release the shuffle memory used by this thread
shuffleMemoryManager.releaseMemoryForThisThread()

// Release the in-memory block and on-disk file when iteration is completed.
val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
mergedItr, releaseUnusedShufflePartition(finalMergedPartArray))

new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
}

override def stop(): Unit = ???

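// Group the map output statuses for this reduce partition by BlockManagerId, count the
// non-empty blocks (numSplits), and return an iterator over the raw fetched block buffers.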
private def fetchRawBlock(): Iterator[(BlockId, Option[ManagedBuffer])] = {
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]()
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => (ShuffleBlockId(handle.shuffleId, s._1, startPartition), s._2)))
}
blocksByAddress.foreach { case (_, blocks) =>
blocks.foreach { case (_, len) => if (len > 0) numSplits += 1 }
}
logInfo(s"Fetch $numSplits partitions for $tid")

shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
context,
SparkEnv.get.blockTransferService,
blockManager,
blocksByAddress,
SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)

val completionItr = CompletionIterator[
(BlockId, Option[ManagedBuffer]),
Iterator[(BlockId, Option[ManagedBuffer])]](shuffleRawBlockFetcherItr, {
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[(BlockId, Option[ManagedBuffer])](context, completionItr)
}

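// Turn each partition into a deserialized iterator. In-memory partitions release their booked
// shuffle memory here (against the task thread's id, since this may run on the merging thread);
// on-disk partitions are read back through the disk store.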
private def getIteratorGroup(shufflePartGroup: Array[ShufflePartition])
: Seq[Iterator[Product2[K, C]]] = {
shufflePartGroup.map { part =>
val itr = part match {
case MemoryPartition(id, buf) =>
// Release memory usage
shuffleMemoryManager.release(buf.size, tid)
blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
case FilePartition(id, file) =>
val blockData = blockManager.diskStore.getBytes(id).getOrElse(
throw new IllegalStateException(s"cannot get data from block $id"))
blockManager.dataDeserialize(id, blockData, ser)
}
itr.asInstanceOf[Iterator[Product2[K, C]]]
}.toSeq
}


/**
* Release the remaining in-memory buffers and on-disk files once they have been merged.
*/
private def releaseUnusedShufflePartition(shufflePartGroup: Array[ShufflePartition]): Unit = {
shufflePartGroup.map { part =>
part match {
case MemoryPartition(id, buf) => buf.release()
case FilePartition(id, file) =>
try {
file.delete()
} catch {
// Swallow the exception
case e: Throwable => logWarning(s"Unexpected errors when deleting file: ${
file.getAbsolutePath}", e)
}
}
}
}

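// Background thread that merges fetched partitions. Each pass merges at most ioSortFactor
// partitions into a single on-disk run; passes repeat until few enough runs remain for the
// final merge performed on the task thread.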
private class MergingThread extends Thread {
private var isLooped = true
private var leftTobeMerged = 0

def setNumSplits(numSplits: Int) {
leftTobeMerged = numSplits
}

override def run() {
while (isLooped) {
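// Fewer than ioSortFactor partitions are still outstanding: pass them straight through
// to the merged group for the final merge.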
if (leftTobeMerged < ioSortFactor && leftTobeMerged > 0) {
var count = leftTobeMerged
while (count > 0) {
val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
if (part != null) {
mergedGroup.offer(part)
count -= 1
leftTobeMerged -= 1
}
}
} else if (leftTobeMerged >= ioSortFactor) {
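// At least ioSortFactor partitions remain: collect a batch of up to ioSortFactor of them,
// merge-sort the batch, and spill the merged run to disk as a single FilePartition.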
val mergingPartArray = ArrayBuffer[ShufflePartition]()
var count = if (numSplits / ioSortFactor > ioSortFactor) {
ioSortFactor
} else {
val mergedSize = mergedGroup.size()
val left = leftTobeMerged - (ioSortFactor - mergedSize - 1)
if (left <= ioSortFactor) {
left
} else {
ioSortFactor
}
}
val countCopy = count

while (count > 0) {
val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
if (part != null) {
mergingPartArray += part
count -= 1
leftTobeMerged -= 1
}
}

// Merge the partitions
val itrGroup = getIteratorGroup(mergingPartArray.toArray)
val partialMergedIter = if (dep.aggregator.isDefined) {
ExternalSorter.mergeWithAggregation(itrGroup, dep.aggregator.get.mergeCombiners,
keyComparator, dep.keyOrdering.isDefined)
} else {
ExternalSorter.mergeSort(itrGroup, keyComparator)
}
// Write merged partitions to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
val fos = new BufferedOutputStream(new FileOutputStream(file), fileBufferSize)
blockManager.dataSerializeStream(tmpBlockId, fos, partialMergedIter, ser)
logInfo(s"Merge $countCopy partitions and write into file ${file.getName}")

releaseUnusedShufflePartition(mergingPartArray.toArray)
mergedGroup.add(FilePartition(tmpBlockId, file))
} else {
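// Nothing is left to collect from the fetcher. If the merged group is still wider than
// ioSortFactor, swap it back into the merging group for another pass; otherwise signal
// that merging is finished.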
val mergedSize = mergedGroup.size()
if (mergedSize > ioSortFactor) {
leftTobeMerged = mergedSize

// Swap the merged group and the merging group and merge again, since the
// number of files is still larger than ioSortFactor
assert(mergingGroup.size() == 0)
mergingGroup.addAll(mergedGroup)
mergedGroup.clear()
} else {
assert(mergingGroup.size() == 0)
isLooped = false
mergeFinished.countDown()
}
}
}
}
}
}
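For reference, a minimal configuration sketch for exercising this reader path. spark.shuffle.ioSortFactor and the merge behaviour above come from this patch; spark.shuffle.manager and spark.shuffle.file.buffer.kb are existing settings, and the values shown are simply the defaults read above:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.manager", "sort")         // route shuffles through SortShuffleManager
  .set("spark.shuffle.ioSortFactor", "100")     // max partitions merged per pass (added here)
  .set("spark.shuffle.file.buffer.kb", "32")    // buffer size (KB) for merged spill file writes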
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -54,6 +54,10 @@ private[spark] class SortShuffleWriter[K, V, C](
       sorter = new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
       sorter.insertAll(records)
+    } else if (dep.keyOrdering.isDefined) {
+      sorter = new ExternalSorter[K, V, V](
+        None, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+      sorter.insertAll(records)
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side