core/src/main/scala/org/apache/spark/TestUtils.scala (51 changes: 29 additions, 22 deletions)
@@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets
 import java.security.SecureRandom
 import java.security.cert.X509Certificate
 import java.util.{Arrays, Properties}
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.util.concurrent.{TimeoutException, TimeUnit}
 import java.util.jar.{JarEntry, JarOutputStream}
 import javax.net.ssl._
 import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
@@ -172,22 +172,22 @@ private[spark] object TestUtils {
   /**
    * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
    */
-  def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
-    val spillListener = new SpillListener
-    sc.addSparkListener(spillListener)
-    body
-    assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+  def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
+    withListener(sc, new SpillListener) { listener =>
+      body
+      assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+    }
   }

   /**
    * Run some code involving jobs submitted to the given context and assert that the jobs
    * did not spill.
    */
-  def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
-    val spillListener = new SpillListener
-    sc.addSparkListener(spillListener)
-    body
-    assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+  def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
+    withListener(sc, new SpillListener) { listener =>
+      body
+      assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+    }
   }

   /**
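For context, here is a minimal sketch of how a suite might call the reworked assertSpilled helper. It is illustrative only: the SparkConf values, the groupByKey job, and the "groupByKey" identifier are assumptions, not code from this PR, and the caller is assumed to live under the org.apache.spark package since TestUtils is private[spark].

```scala
import org.apache.spark.{SparkConf, SparkContext, TestUtils}

// Hypothetical test scaffolding; the threshold below is an internal Spark config that
// tests commonly use to make spilling likely, chosen here only for illustration.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("assertSpilled-example")
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")
val sc = new SparkContext(conf)
try {
  // The body is an ordinary job submitted through `sc`; the helper installs a
  // SpillListener around it and asserts that at least one stage spilled.
  TestUtils.assertSpilled(sc, "groupByKey") {
    sc.parallelize(1 to 1000, 2).map(i => (i % 10, i)).groupByKey().count()
  }
} finally {
  sc.stop()
}
```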
@@ -233,6 +233,21 @@ private[spark] object TestUtils {
     }
   }

+  /**
+   * Runs some code with the given listener installed in the SparkContext. After the code runs,
+   * this method will wait until all events posted to the listener bus are processed, and then
+   * remove the listener from the bus.
+   */
+  def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = {
[Review comment, Member]: private? hardly matters.
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+    } finally {
+      sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   /**
    * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting
    * time elapsed before `numExecutors` executors up. Exposed for testing.
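The new withListener helper is not tied to SpillListener; any SparkListener can be installed for the duration of a block and is removed afterwards. A hedged sketch follows, with a made-up StageCountListener and job, again assuming an existing SparkContext sc and a caller inside the org.apache.spark package.

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.TestUtils
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

// Hypothetical listener used only for this example: it counts completed stages.
class StageCountListener extends SparkListener {
  val completedStages = new AtomicInteger(0)
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    completedStages.incrementAndGet()
  }
}

// The listener is registered before the block runs; the finally clause in withListener
// drains the listener bus and removes it, so it cannot leak into later tests.
TestUtils.withListener(sc, new StageCountListener) { listener =>
  sc.parallelize(1 to 100, 4).count()
  // `listener` can be inspected here, keeping in mind that events may still be in
  // flight until the bus is drained on the way out.
}
```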
@@ -289,30 +304,22 @@
 private class SpillListener extends SparkListener {
   private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
   private val spilledStageIds = new mutable.HashSet[Int]
-  private val stagesDone = new CountDownLatch(1)

-  def numSpilledStages: Int = {
-    // Long timeout, just in case somehow the job end isn't notified.
-    // Fails if a timeout occurs
-    assert(stagesDone.await(10, TimeUnit.SECONDS))
+  def numSpilledStages: Int = synchronized {
     spilledStageIds.size
   }

-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
     stageIdToTaskMetrics.getOrElseUpdate(
       taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
   }

-  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = {
+  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized {
     val stageId = stageComplete.stageInfo.stageId
     val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
     val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
     if (spilled) {
       spilledStageIds += stageId
     }
   }
-
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
-    stagesDone.countDown()
-  }
 }
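A note on the SpillListener change, as I read the diff rather than as stated in the PR: listener callbacks arrive on the listener-bus thread while tests read numSpilledStages from their own thread, so the accessors are now synchronized. The CountDownLatch/onJobEnd handshake can go because a caller that drains the bus, as withListener's finally block does, knows that all events have been delivered before the listener is torn down. A rough sketch of that ordering, assuming code in the org.apache.spark package and an existing SparkContext sc:

```scala
import java.util.concurrent.TimeUnit

val listener = new SpillListener            // callbacks will run on the listener-bus thread
sc.addSparkListener(listener)
try {
  sc.parallelize(1 to 1000, 2).map(i => (i % 10, i)).groupByKey().count()
} finally {
  // Drain the bus so every onTaskEnd/onStageCompleted has been delivered...
  sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
  sc.listenerBus.removeListener(listener)
}
// ...before the test thread reads the synchronized counter.
assert(listener.numSpilledStages > 0)
```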