init commit
lianhuiwang committed Jun 29, 2015
commit f149147f0283d8c0e53bb850b1548ee6d84d8a70
30 changes: 30 additions & 0 deletions core/src/main/java/org/apache/spark/Spillable.java
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

/**
* Forces the contents of a memory buffer to be spilled to disk, releasing the memory.
*/
public interface Spillable {

/**
* Forces the contents of the memory buffer to be spilled to disk.
* @return the number of bytes spilled
*/
public long forceSpill();
}
core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle

import scala.collection.mutable

import org.apache.spark.{Logging, SparkException, SparkConf}
import org.apache.spark._

/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
@@ -38,8 +38,47 @@ import org.apache.spark.{Logging, SparkException, SparkConf}
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes

// threadId -> memory reserved list
private val threadReservedList = new mutable.HashMap[Long, mutable.ListBuffer[Spillable]]()

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

/**
* Release memory held by other Spillables of the current thread until the granted amount reaches requestedAmount
*/
def releaseReservedMemory(toGrant: Long, requestedAmount: Long): Long = synchronized {
val threadId = Thread.currentThread().getId
if (toGrant >= requestedAmount || !threadReservedList.contains(threadId)) {
toGrant
} else {
// try to spill Spillables registered by the current thread to make space for the new request
var addedMemory = toGrant
while (addedMemory < requestedAmount && threadReservedList(threadId).nonEmpty) {
val toSpill = threadReservedList(threadId).remove(0)
val spillMemory = toSpill.forceSpill()
logInfo(s"Thread $threadId force-spilled $spillMemory bytes to free up memory")
addedMemory += spillMemory
}
if (addedMemory > requestedAmount) {
this.release(addedMemory - requestedAmount)
addedMemory = requestedAmount
}
addedMemory
}
}

/**
* Add a Spillable to the reserved list of the current thread. When the current thread
* does not have enough memory, memory held by entries on this list can be released.
*/
def addSpillableToReservedList(spill: Spillable): Unit = synchronized {
val threadId = Thread.currentThread().getId
if (!threadReservedList.contains(threadId)) {
threadReservedList(threadId) = new mutable.ListBuffer[Spillable]()
}
threadReservedList(threadId) += spill
}

/**
* Try to acquire up to numBytes memory for the current thread, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
@@ -108,6 +147,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
def releaseMemoryForThisThread(): Unit = synchronized {
val threadId = Thread.currentThread().getId
threadMemory.remove(threadId)
threadReservedList.remove(threadId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -17,14 +17,15 @@

package org.apache.spark.util.collection

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.{Logging, SparkEnv, Spillable}

import scala.reflect.ClassTag
Contributor comment: nit: scala imports before spark imports


/**
* Spills contents of an in-memory collection to disk when the memory threshold
* has been exceeded.
*/
private[spark] trait Spillable[C] extends Logging {
private[spark] trait CollectionSpillable[C] extends Logging with Spillable {
/**
* Spills the current in-memory collection to disk, and releases the memory.
*
@@ -40,25 +41,25 @@ private[spark] trait Spillable[C] extends Logging {
protected def addElementsRead(): Unit = { _elementsRead += 1 }

// Memory manager that can be used to acquire/release memory
private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Initial threshold for the size of a collection before we start tracking its memory usage
// Exposed for testing
private[this] val initialMemoryThreshold: Long =
protected val initialMemoryThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)

// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
private[this] var myMemoryThreshold = initialMemoryThreshold
protected var myMemoryThreshold = initialMemoryThreshold

// Number of elements read from input since last spill
private[this] var _elementsRead = 0L
protected var _elementsRead = 0L

// Number of bytes spilled in total
private[this] var _memoryBytesSpilled = 0L
protected var _memoryBytesSpilled = 0L

// Number of spills
private[this] var _spillCount = 0
protected var _spillCount = 0

/**
* Spills the current in-memory collection to disk if needed. Attempts to acquire more
@@ -111,7 +112,7 @@ private[spark] trait Spillable[C] extends Logging {
*
* @param size number of bytes spilled
*/
@inline private def logSpillage(size: Long) {
@inline protected def logSpillage(size: Long) {
val threadId = Thread.currentThread().getId
logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
.format(threadId, org.apache.spark.util.Utils.bytesToString(size),
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -69,7 +69,7 @@ class ExternalAppendOnlyMap[K, V, C](
extends Iterable[(K, C)]
with Serializable
with Logging
with Spillable[SizeTracker] {
with CollectionSpillable[SizeTracker] {

private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
@@ -100,6 +100,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

private var memoryOrDiskIter: Option[MemoryOrDiskIterator] = None

/**
* Insert the given key and value into the map.
*/
@@ -151,6 +153,30 @@
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
val it = currentMap.destructiveSortedIterator(keyComparator)
spilledMaps.append(spillMemoryToDisk(it))
}

def diskBytesSpilled: Long = _diskBytesSpilled

/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
*/
override def iterator: Iterator[(K, C)] = {
shuffleMemoryManager.addSpillableToReservedList(this)
if (spilledMaps.isEmpty) {
memoryOrDiskIter = Some(MemoryOrDiskIterator(currentMap.iterator))
memoryOrDiskIter.get
} else {
new ExternalIterator()
}
}

/**
* Spill the contents of the in-memory map to a temporary file on disk.
*/
private[this] def spillMemoryToDisk(it: Iterator[(K, C)]): DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -171,7 +197,6 @@ class ExternalAppendOnlyMap[K, V, C](

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
writer.write(kv._1, kv._2)
@@ -203,21 +228,52 @@
}
}
}

spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
new DiskMapIterator(file, blockId, batchSizes)
}

def diskBytesSpilled: Long = _diskBytesSpilled
/**
* Spill the contents of the in-memory map to disk.
*/
override def forceSpill(): Long = {
var freeMemory = 0L
if (memoryOrDiskIter.isDefined) {
_spillCount += 1
logSpillage(currentMap.estimateSize())

memoryOrDiskIter.get.spill()

_elementsRead = 0
_memoryBytesSpilled += currentMap.estimateSize()
freeMemory = myMemoryThreshold - initialMemoryThreshold
myMemoryThreshold = initialMemoryThreshold
}

freeMemory
}

/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
* An iterator that reads elements from the in-memory iterator, or from the disk
* iterator after the in-memory contents have been spilled to disk.
*/
override def iterator: Iterator[(K, C)] = {
if (spilledMaps.isEmpty) {
currentMap.iterator
} else {
new ExternalIterator()
case class MemoryOrDiskIterator(memIt: Iterator[(K, C)]) extends Iterator[(K, C)] {

var currentIt = memIt

override def hasNext: Boolean = currentIt.hasNext

override def next(): (K, C) = currentIt.next()

def spill(): Unit = {
if (hasNext) {
currentIt = spillMemoryToDisk(currentIt)
} else {
// the in-memory iterator is already drained; release it by swapping in an empty iterator
currentIt = Iterator.empty
logInfo("Nothing left in the in-memory iterator; nothing to spill")
}
}
}

@@ -232,7 +288,9 @@ class ExternalAppendOnlyMap[K, V, C](

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
memoryOrDiskIter = Some(MemoryOrDiskIterator(
currentMap.destructiveSortedIterator(keyComparator)))
private val sortedMap = memoryOrDiskIter.get
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>