@@ -1430,6 +1430,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
.doc("The max number of entries to be stored in queue to wait for late epochs. " +
"If this parameter is exceeded by the size of the queue, stream will stop with an error.")
.intConf
.createWithDefault(10000)

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -2041,6 +2048,9 @@ class SQLConf extends Serializable with Logging {

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)

def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)

def continuousStreamingExecutorPollIntervalMs: Long =
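
As a usage illustration only (not part of this patch), here is a minimal sketch of tuning the new limit before starting a continuous query; the rate-source options, console sink, and checkpoint path are arbitrary examples:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("epoch-backlog-example")
  // Fail the query instead of letting the epoch backlog grow without bound.
  .config("spark.sql.streaming.continuous.epochBacklogQueueSize", "1000")
  .getOrCreate()

val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "100")
  .load()
  .writeStream
  .format("console")
  .option("checkpointLocation", "/tmp/epoch-backlog-example") // illustrative path
  .trigger(Trigger.Continuous("1 second"))
  .start()
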
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous

import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import java.util.function.UnaryOperator

import scala.collection.JavaConverters._
@@ -58,6 +59,9 @@ class ContinuousExecution(
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _

// Throwable that caused the execution to fail
private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null)

override val logicalPlan: LogicalPlan = {
val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]()
var nextSourceId = 0
@@ -261,6 +265,11 @@ class ContinuousExecution(
lastExecution.toRdd
}
}

val f = failure.get()
if (f != null) {
throw f
}
} catch {
case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
state.get() == RECONFIGURING =>
@@ -373,6 +382,35 @@ class ContinuousExecution(
}
}

/**
* Stores the error and stops the query execution thread to terminate the query in a new thread.
*/
def stopInNewThread(error: Throwable): Unit = {
if (failure.compareAndSet(null, error)) {
logError(s"Query $prettyIdString received exception $error")
stopInNewThread()
Reviewer comment: Looks like there is a race here. The query stop may happen before the continuous execution thread checks the stored failure, and the query will then just stop without any exception, exactly as if someone had stopped it manually.

}
}

/**
* Stops the query execution thread to terminate the query in a new thread.
*/
private def stopInNewThread(): Unit = {
new Thread("stop-continuous-execution") {
setDaemon(true)

override def run(): Unit = {
try {
ContinuousExecution.this.stop()
} catch {
case e: Throwable =>
logError(e.getMessage, e)
throw e
}
}
}.start()
}
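
The handoff above follows a common pattern: record only the first error atomically, request the stop from a separate daemon thread so the reporting side (here an RPC handler) never blocks, and let the driving loop rethrow. A self-contained sketch of that pattern, independent of Spark's classes:

import java.util.concurrent.atomic.AtomicReference

class FailableLoop {
  // Holds the first reported error; later reports are ignored.
  private val failure = new AtomicReference[Throwable](null)
  @volatile private var running = true

  def stopWithError(error: Throwable): Unit = {
    if (failure.compareAndSet(null, error)) {
      // Stop asynchronously so the reporting thread returns immediately.
      val stopper = new Thread("stop-failable-loop") {
        setDaemon(true)
        override def run(): Unit = running = false
      }
      stopper.start()
    }
  }

  def run(): Unit = {
    while (running) {
      // ... one unit of work per iteration ...
      Thread.sleep(10)
    }
    // Rethrow on the driving thread so callers see the original failure.
    val f = failure.get()
    if (f != null) throw f
  }
}

As the review comment above points out, the real code still has to guarantee that the stored failure is observed before the query reports a clean stop; in this sketch the rethrow after the loop plays that role.
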

/**
* Stops the query execution thread to terminate the query.
*/
@@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {

private val epochBacklogQueueSize =
session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize

private var queryWritesStopped: Boolean = false

private var numReaderPartitions: Int = _
@@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
partitionCommits.put((epoch, partitionId), message)
resolveCommitsAtEpoch(epoch)
checkProcessingQueueBoundaries()
}

case ReportPartitionOffset(partitionId, epoch, offset) =>
@@ -223,6 +227,22 @@
query.addOffset(epoch, stream, thisEpochOffsets.toSeq)
resolveCommitsAtEpoch(epoch)
}
checkProcessingQueueBoundaries()
}

private def checkProcessingQueueBoundaries(): Unit = {
if (partitionOffsets.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
"exceeded it's maximum"))
}
if (partitionCommits.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
"exceeded it's maximum"))
}
if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
"exceeded it's maximum"))
}
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
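
From a user's point of view, the new bound surfaces through the normal failure path: the query stops and awaitTermination() rethrows the error as a StreamingQueryException. A hedged sketch of reacting to it, assuming `query` is a running continuous StreamingQuery (as in the earlier example) and that the IllegalStateException is exposed as the exception's cause:

import org.apache.spark.sql.streaming.StreamingQueryException

try {
  query.awaitTermination()
} catch {
  case e: StreamingQueryException =>
    val cause = e.getCause
    if (cause != null && cause.getMessage != null &&
        cause.getMessage.contains("queue has exceeded its maximum")) {
      // The epoch backlog overflowed: give the query more executors or raise
      // spark.sql.streaming.continuous.epochBacklogQueueSize before restarting.
    } else {
      throw e
    }
}
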
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.TestSparkSession

@@ -343,3 +344,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
}
}
}

class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
import testImplicits._

override protected def createSparkSession = new TestSparkSession(
new SparkContext(
"local[1]",
"continuous-stream-test-sql-context",
sparkConf.set("spark.sql.testkey", "true")))

// This test forces the backlog to overflow by not standing up enough executors for the query
// to make progress.
test("epoch backlog overflow") {
withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) {
val df = spark.readStream
.format("rate")
.option("numPartitions", "2")
.option("rowsPerSecond", "500")
.load()
.select('value)

testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(1)),
ExpectFailure[IllegalStateException] { e =>
assert(e.getMessage.contains("queue has exceeded its maximum"))
}
)
}
}
}
@@ -17,16 +17,17 @@

package org.apache.spark.sql.streaming.continuous

import org.mockito.{ArgumentCaptor, InOrder}
import org.mockito.ArgumentMatchers.{any, eq => eqTo}
import org.mockito.InOrder
import org.mockito.Mockito.{inOrder, never, verify}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar

import org.apache.spark._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
@@ -43,14 +44,20 @@ class EpochCoordinatorSuite
private var writeSupport: StreamingWrite = _
private var query: ContinuousExecution = _
private var orderVerifier: InOrder = _
private val epochBacklogQueueSize = 10

override def beforeEach(): Unit = {
val stream = mock[ContinuousStream]
writeSupport = mock[StreamingWrite]
query = mock[ContinuousExecution]
orderVerifier = inOrder(writeSupport, query)

spark = new TestSparkSession()
spark = new TestSparkSession(
new SparkContext(
"local[2]", "test-sql-context",
new SparkConf().set("spark.sql.testkey", "true")
.set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key,
epochBacklogQueueSize.toString)))

epochCoordinator
= EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get)
@@ -186,6 +193,66 @@ class EpochCoordinatorSuite
verifyCommitsInOrderOf(List(1, 2, 3, 4, 5))
}

test("several epochs, max epoch backlog reached by partitionOffsets") {
setWriterPartitions(1)
setReaderPartitions(1)

reportPartitionOffset(0, 1)
// Commit messages not arriving
for (i <- 2 to epochBacklogQueueSize + 1) {
reportPartitionOffset(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition offset queue has exceeded it's maximum")
}

test("several epochs, max epoch backlog reached by partitionCommits") {
setWriterPartitions(1)
setReaderPartitions(1)

commitPartitionEpoch(0, 1)
// Offset messages not arriving
for (i <- 2 to epochBacklogQueueSize + 1) {
commitPartitionEpoch(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition commit queue has exceeded it's maximum")
}

test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") {
setWriterPartitions(2)
setReaderPartitions(2)

commitPartitionEpoch(0, 1)
reportPartitionOffset(0, 1)

// Messages for epoch 1 of the second partition never arrive, so later epochs
// pile up waiting to be committed. +2 because the incomplete first epoch
// never enters the waiting queue itself.
for (i <- 2 to epochBacklogQueueSize + 2) {
commitPartitionEpoch(0, i)
reportPartitionOffset(0, i)
commitPartitionEpoch(1, i)
reportPartitionOffset(1, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 2) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the epoch queue has exceeded it's maximum")
}

private def setWriterPartitions(numPartitions: Int): Unit = {
epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions))
}
@@ -221,4 +288,13 @@ class EpochCoordinatorSuite
private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = {
epochs.foreach(verifyCommit)
}

private def verifyStoppedWithException(msg: String): Unit = {
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture())

import scala.collection.JavaConverters._
val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg)
assert(throwable.isDefined, "Stream stopped with an exception but the expected message is missing")
}
}