diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index b5c4c705dcbc..6cc8fe1173d2 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.{Arrays, Properties} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -172,22 +172,22 @@ private[spark] object TestUtils { /** * Run some code involving jobs submitted to the given context and assert that the jobs spilled. */ - def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + withListener(sc, new SpillListener) { listener => + body + assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + } } /** * Run some code involving jobs submitted to the given context and assert that the jobs * did not spill. */ - def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + withListener(sc, new SpillListener) { listener => + body + assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + } } /** @@ -233,6 +233,21 @@ private[spark] object TestUtils { } } + /** + * Runs some code with the given listener installed in the SparkContext. After the code runs, + * this method will wait until all events posted to the listener bus are processed, and then + * remove the listener from the bus. + */ + def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = { + sc.addSparkListener(listener) + try { + body(listener) + } finally { + sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10)) + sc.listenerBus.removeListener(listener) + } + } + /** * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting * time elapsed before `numExecutors` executors up. Exposed for testing. @@ -289,21 +304,17 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] - private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = { - // Long timeout, just in case somehow the job end isn't notified. - // Fails if a timeout occurs - assert(stagesDone.await(10, TimeUnit.SECONDS)) + def numSpilledStages: Int = synchronized { spilledStageIds.size } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { stageIdToTaskMetrics.getOrElseUpdate( taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics } - override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized { val stageId = stageComplete.stageInfo.stageId val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 @@ -311,8 +322,4 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - stagesDone.countDown() - } }