51 changes: 51 additions & 0 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -24,10 +24,14 @@ import java.util.Arrays
import java.util.jar.{JarEntry, JarOutputStream}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils

/**
@@ -154,4 +158,51 @@ private[spark] object TestUtils {
" @Override public String toString() { return \"" + toStringValue + "\"; }}")
createCompiledClass(className, destDir, sourceFile, classpathUrls)
}

/**
* Run some code involving jobs submitted to the given context and assert that the jobs spilled.
*/
def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
val spillListener = new SpillListener
sc.addSparkListener(spillListener)
body
assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
}

/**
* Run some code involving jobs submitted to the given context and assert that the jobs
* did not spill.
*/
def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
val spillListener = new SpillListener
sc.addSparkListener(spillListener)
body
assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}

}


/**
* A [[SparkListener]] that detects whether spills have occurred in Spark jobs.
*/
private class SpillListener extends SparkListener {
private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
private val spilledStageIds = new mutable.HashSet[Int]

def numSpilledStages: Int = spilledStageIds.size

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
stageIdToTaskMetrics.getOrElseUpdate(
taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
}

override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = {
val stageId = stageComplete.stageInfo.stageId
val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
if (spilled) {
spilledStageIds += stageId
}
}
}
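
A hedged usage sketch of the new helpers (not part of this diff): TestUtils is private[spark], so code like this would live inside Spark's own test sources, and the job, sizes, and configuration values below are illustrative only. The force-spill key referenced here is the one introduced in the Spillable.scala hunk further down, and whether the assertion observes the spill also depends on the test harness draining listener events before the check.

import org.apache.spark.{SparkConf, SparkContext, TestUtils}

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("spill-assertion-sketch")
  // Force spills after a handful of records so the listener has events to see.
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")
val sc = new SparkContext(conf)
try {
  TestUtils.assertSpilled(sc, "groupByKey") {
    sc.parallelize(1 to 1000, 2).map(i => (i % 10, i)).groupByKey().count()
  }
  TestUtils.assertNotSpilled(sc, "map-only job") {
    sc.parallelize(1 to 1000, 2).map(_ + 1).count()
  }
} finally {
  sc.stop()
}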
core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -139,8 +139,10 @@ class ShuffleMemoryManager protected (
throw new SparkException(
s"Internal error: release called on $numBytes bytes but task only has $curMem")
}
taskMemory(taskAttemptId) -= numBytes
memoryManager.releaseExecutionMemory(numBytes)
if (taskMemory.contains(taskAttemptId)) {
taskMemory(taskAttemptId) -= numBytes
memoryManager.releaseExecutionMemory(numBytes)
}
memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
}

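The contains guard above makes the release a no-op when the calling task attempt no longer has an entry in taskMemory (for instance, when its memory has already been cleaned up), rather than decrementing a missing entry. A minimal standalone sketch of the same defensive pattern, using a hypothetical accounting object rather than the real ShuffleMemoryManager internals:

import scala.collection.mutable

// Hypothetical accounting: taskAttemptId -> bytes currently held by that attempt.
object MemoryAccounting {
  private val taskMemory = new mutable.HashMap[Long, Long]()

  def acquire(taskAttemptId: Long, numBytes: Long): Unit = synchronized {
    taskMemory(taskAttemptId) = taskMemory.getOrElse(taskAttemptId, 0L) + numBytes
  }

  def release(taskAttemptId: Long, numBytes: Long): Unit = synchronized {
    // Skip the decrement if the attempt is no longer tracked, instead of
    // re-creating (or corrupting) a stale entry.
    if (taskMemory.contains(taskAttemptId)) {
      taskMemory(taskAttemptId) -= numBytes
    }
  }
}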
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -95,6 +95,12 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

/**
* Number of files this map has spilled so far.
* Exposed for testing.
*/
private[collection] def numSpills: Int = spilledMaps.size

/**
* Insert the given key and value into the map.
*/
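A hedged test sketch for the new numSpills accessor. It would need to sit in package org.apache.spark.util.collection (numSpills is private[collection]) and run with an active SparkContext so SparkEnv is initialized; the combiner functions and counts are illustrative, and a spill only happens if the memory or force-spill thresholds are actually hit.

// Sum-combining map: createCombiner, mergeValue, mergeCombiners.
val map = new ExternalAppendOnlyMap[Int, Int, Int](i => i, _ + _, _ + _)

// Assumes spark.shuffle.spill.numElementsForceSpillThreshold was set low on
// the SparkConf backing this context (see the Spillable.scala change below).
(1 to 100000).foreach(i => map.insert(i % 1000, i))

assert(map.numSpills > 0, "expected the map to have spilled at least once")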
core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -43,10 +43,15 @@ private[spark] trait Spillable[C] extends Logging {
private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Initial threshold for the size of a collection before we start tracking its memory usage
// Exposed for testing
// For testing only
private[this] val initialMemoryThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)

// Force this collection to spill when there are this many elements in memory
// For testing only
private[this] val numElementsForceSpillThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)

// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -69,27 +74,27 @@ private[spark] trait Spillable[C] extends Logging {
* @return true if `collection` was spilled to disk; false otherwise
*/
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
myMemoryThreshold += granted
if (myMemoryThreshold <= currentMemory) {
// We were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold); spill the current collection
_spillCount += 1
logSpillage(currentMemory)

spill(collection)

_elementsRead = 0
// Keep track of spills, and release memory
_memoryBytesSpilled += currentMemory
releaseMemoryForThisThread()
return true
}
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
shouldSpill = currentMemory >= myMemoryThreshold
}
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemoryForThisThread()
}
false
shouldSpill
}

/**
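For readers following the refactor above, the new decision can be restated as a hypothetical pure function with the memory request factored out into its inputs; it mirrors the hunk but is not the real code, which also updates myMemoryThreshold and the spill counters.

// granted is what shuffleMemoryManager.tryToAcquire returned (0 if nothing).
def decideSpill(
    elementsRead: Long,
    currentMemory: Long,
    thresholdBefore: Long,
    granted: Long,
    forceSpillThreshold: Long): Boolean = {
  // Memory-based path: only checked every 32 elements, and only spills if the
  // grant still leaves the threshold at or below the collection's current size.
  val triedToGrow = elementsRead % 32 == 0 && currentMemory >= thresholdBefore
  val thresholdAfter = if (triedToGrow) thresholdBefore + granted else thresholdBefore
  val spillForMemory = triedToGrow && currentMemory >= thresholdAfter
  // Element-count path: unconditional once the force-spill threshold is crossed.
  spillForMemory || elementsRead > forceSpillThreshold
}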
39 changes: 26 additions & 13 deletions core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -203,22 +203,35 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext
}

test("compute without caching when no partitions fit in memory") {
sc = new SparkContext(clusterUrl, "test")
// data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache
// to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory
val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER)
assert(data.count() === 4000000)
assert(data.count() === 4000000)
assert(data.count() === 4000000)
val size = 10000
val conf = new SparkConf()
.set("spark.storage.unrollMemoryThreshold", "1024")
.set("spark.testing.memory", (size / 2).toString)
sc = new SparkContext(clusterUrl, "test", conf)
val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY)
assert(data.count() === size)
assert(data.count() === size)
assert(data.count() === size)
// ensure only a subset of partitions were cached
val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
assert(rddBlocks.size === 0, s"expected no RDD blocks, found ${rddBlocks.size}")
}

test("compute when only some partitions fit in memory") {
sc = new SparkContext(clusterUrl, "test", new SparkConf)
// TODO: verify that only a subset of partitions fit in memory (SPARK-11078)
val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER)
assert(data.count() === 4000000)
assert(data.count() === 4000000)
assert(data.count() === 4000000)
val size = 10000
val numPartitions = 10
val conf = new SparkConf()
.set("spark.storage.unrollMemoryThreshold", "1024")
.set("spark.testing.memory", (size * numPartitions).toString)
sc = new SparkContext(clusterUrl, "test", conf)
val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY)
assert(data.count() === size)
assert(data.count() === size)
assert(data.count() === size)
// ensure only a subset of partitions were cached
val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
assert(rddBlocks.size > 0, "no RDD blocks found")
assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions")
}

test("passing environment variables to cluster") {
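The rewritten tests above shrink the memory Spark believes it has (spark.testing.memory) and lower spark.storage.unrollMemoryThreshold instead of materializing millions of records. A hedged local-mode sketch of the same pattern, again assuming it runs inside Spark's own test sources (sc.env is private[spark]); the values are illustrative and clusterUrl is specific to DistributedSuite:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

val size = 10000
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("tiny-cache-sketch")
  .set("spark.storage.unrollMemoryThreshold", "1024")
  // Cap the memory the block manager thinks is available so that only some
  // (or none) of the cached partitions fit.
  .set("spark.testing.memory", (size * 10).toString)
val sc = new SparkContext(conf)
try {
  val data = sc.parallelize(1 to size, 10).persist(StorageLevel.MEMORY_ONLY)
  assert(data.count() == size)
  val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
  assert(rddBlocks.size < 10, "expected at most a subset of partitions to be cached")
} finally {
  sc.stop()
}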