From bcc0297e8af639132bebb0703eb7df8c8fea15f9 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 24 Jun 2016 21:50:13 +0100 Subject: [PATCH 1/2] Make spill tests wait until job has completed before returning the number of stages that spilled --- core/src/main/scala/org/apache/spark/TestUtils.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 43c89b258f2f..3058606f2e48 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,6 +22,7 @@ import java.net.{URI, URL} import java.nio.charset.StandardCharsets import java.nio.file.Paths import java.util.Arrays +import java.util.concurrent.CountDownLatch import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConverters._ @@ -190,8 +191,12 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] + private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = spilledStageIds.size + def numSpilledStages: Int = { + stagesDone.await() + spilledStageIds.size + } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { stageIdToTaskMetrics.getOrElseUpdate( @@ -206,4 +211,8 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + stagesDone.countDown() + } } From 14922800d1a98f7d2305017fcd5fa46847d669d3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 25 Jun 2016 07:42:38 +0100 Subject: [PATCH 2/2] Add timeout to new wait condition for safety --- core/src/main/scala/org/apache/spark/TestUtils.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 3058606f2e48..871b9d1ad575 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{URI, URL} import java.nio.charset.StandardCharsets import java.nio.file.Paths import java.util.Arrays -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConverters._ @@ -194,7 +194,9 @@ private class SpillListener extends SparkListener { private val stagesDone = new CountDownLatch(1) def numSpilledStages: Int = { - stagesDone.await() + // Long timeout, just in case somehow the job end isn't notified. + // Fails if a timeout occurs + assert(stagesDone.await(10, TimeUnit.SECONDS)) spilledStageIds.size }