diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 6b664b7a7dfd4..45f42e6bc8f99 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -73,7 +73,7 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { executorStarted.acquire() // Verify that receiver was started - assert(receiver.onStartCalled) + assert(receiver.callsRecorder.calls === Seq("onStart")) assert(executor.isReceiverStarted) assert(receiver.isStarted) assert(!receiver.isStopped()) @@ -106,19 +106,22 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { assert(executor.errors.head.eq(exception)) // Verify restarting actually stops and starts the receiver - receiver.restart("restarting", null, 600) - eventually(timeout(300.milliseconds), interval(10.milliseconds)) { - // receiver will be stopped async - assert(receiver.isStopped) - assert(receiver.onStopCalled) - } - eventually(timeout(1.second), interval(10.milliseconds)) { - // receiver will be started async - assert(receiver.onStartCalled) - assert(executor.isReceiverStarted) + executor.callsRecorder.reset() + receiver.callsRecorder.reset() + receiver.restart("restarting", null, 100) + eventually(timeout(10.seconds), interval(10.milliseconds)) { + // below verification ensures for now receiver is already restarted assert(receiver.isStarted) assert(!receiver.isStopped) assert(receiver.receiving) + + // both receiver supervisor and receiver should be stopped first, and started + assert(executor.callsRecorder.calls === Seq("onReceiverStop", "onReceiverStart")) + assert(receiver.callsRecorder.calls === Seq("onStop", "onStart")) + + // check whether the delay between stop and start is respected + assert(executor.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100) + assert(receiver.callsRecorder.timestamps.reverse.reduceLeft { _ - _ } >= 100) } // Verify that stopping actually stops the thread @@ -290,6 +293,9 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]] val errors = new ArrayBuffer[Throwable] + // tracks calls of "onReceiverStart", "onReceiverStop" + val callsRecorder = new MethodsCallRecorder() + /** Check if all data structures are clean */ def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && @@ -325,7 +331,15 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { errors += throwable } - override protected def onReceiverStart(): Boolean = true + override protected def onReceiverStart(): Boolean = { + callsRecorder.record() + true + } + + override protected def onReceiverStop(message: String, error: Option[Throwable]): Unit = { + callsRecorder.record() + super.onReceiverStop(message, error) + } override def createBlockGenerator( blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { @@ -363,36 +377,55 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { @volatile var otherThread: Thread = null @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false + + // tracks calls of "onStart", "onStop" + @transient lazy val callsRecorder = new MethodsCallRecorder() def onStart() { otherThread = new Thread() { override def run() { receiving = true - var count = 0 - while(!isStopped()) { - if (sendData) { - store(count) - count += 1 + try { + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) } - Thread.sleep(10) + } finally { + receiving = false } } } - onStartCalled = true + callsRecorder.record() otherThread.start() } def onStop() { - onStopCalled = true + callsRecorder.record() otherThread.join() } +} + +class MethodsCallRecorder { + // tracks calling methods as (timestamp, methodName) + private val records = new ArrayBuffer[(Long, String)] + + def record(): Unit = records.append((System.currentTimeMillis(), callerMethodName)) + + def reset(): Unit = records.clear() - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + def callsWithTimestamp: scala.collection.immutable.Seq[(Long, String)] = records.toList + + def calls: scala.collection.immutable.Seq[String] = records.map(_._2).toList + + def timestamps: scala.collection.immutable.Seq[Long] = records.map(_._1).toList + + private def callerMethodName: String = { + val stackTrace = new Throwable().getStackTrace + // it should return method name of two levels deeper + stackTrace(2).getMethodName } } -