sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -17,16 +17,18 @@

 package org.apache.spark.sql.execution.streaming
 
-import java.io.{InterruptedIOException, IOException}
+import java.io.{InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
 import java.util.UUID
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.collection.mutable.{Map => MutableMap}
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -335,7 +337,7 @@ class StreamExecution(
         // `stop()` is already called. Let `finally` finish the cleanup.
       }
     } catch {
-      case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED =>
+      case e if isInterruptedByStop(e) =>
         // interrupted by stop()
         updateStatusMessage("Stopped")
       case e: IOException if e.getMessage != null
@@ -407,6 +409,32 @@
     }
   }
 
+  private def isInterruptedByStop(e: Throwable): Boolean = {
+    if (state.get == TERMINATED) {
+      e match {
+        // InterruptedIOException - thrown when an I/O operation is interrupted
+        // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted
+        case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
+          true
+        // The cause of the following exceptions may be one of the above exceptions:
+        //
+        // UncheckedIOException - thrown by code that cannot throw a checked IOException, such as
+        //                        BiFunction.apply
+        // ExecutionException - thrown when code running in a thread pool throws an exception
+        // UncheckedExecutionException - thrown by code that cannot throw a checked
+        //                               ExecutionException, such as BiFunction.apply
+        case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException)
+            if e2.getCause != null =>
+          isInterruptedByStop(e2.getCause)
+        case _ =>
+          false
+      }
+    } else {
+      false
+    }
+  }
+
   /**
    * Populate the start offsets to start the execution at the current offsets stored in the sink
    * (i.e. avoid reprocessing data that we have already processed). This function must be called
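Note on the fix: the old handler only recognized a bare `InterruptedException` or `InterruptedIOException`, so an interrupt that surfaced as `ClosedByInterruptException`, or arrived re-thrown inside a wrapper (`UncheckedIOException`, `ExecutionException`, Guava's `UncheckedExecutionException`), was reported as a query failure even though it was caused by `stop()`. The new `isInterruptedByStop` unwraps `getCause` recursively, so arbitrarily deep wrapper chains are classified correctly. A minimal standalone sketch of that classification (names below are mine, not part of the patch; the `state.get == TERMINATED` guard and the Guava wrapper are omitted so it runs with the JDK alone):

```scala
import java.io.{InterruptedIOException, UncheckedIOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.ExecutionException

object InterruptClassifierSketch {
  // Mirrors the shape of isInterruptedByStop: direct interrupts match
  // immediately; known wrappers are peeled one level and re-checked.
  def looksLikeInterrupt(e: Throwable): Boolean = e match {
    case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
      true
    case e2 @ (_: UncheckedIOException | _: ExecutionException) if e2.getCause != null =>
      looksLikeInterrupt(e2.getCause)
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    // A doubly wrapped interrupt is still recognized.
    val nested = new ExecutionException(
      new UncheckedIOException(new ClosedByInterruptException))
    println(looksLikeInterrupt(nested))               // true
    println(looksLikeInterrupt(new RuntimeException)) // false
  }
}
```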
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,12 +17,14 @@

 package org.apache.spark.sql.streaming
 
-import java.io.{File, InterruptedIOException, IOException}
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.io.{File, InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit}
 
 import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
+import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration

@@ -690,6 +692,31 @@ class StreamSuite extends StreamTest {
       }
     }
   }
+
+  for (e <- Seq(
+      new InterruptedException,
+      new InterruptedIOException,
+      new ClosedByInterruptException,
+      new UncheckedIOException("test", new ClosedByInterruptException),
+      new ExecutionException("test", new InterruptedException),
+      new UncheckedExecutionException("test", new InterruptedException))) {
+    test(s"view ${e.getClass.getSimpleName} as a normal query stop") {
+      ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1)
+      ThrowingExceptionInCreateSource.exception = e
+      val query = spark
+        .readStream
+        .format(classOf[ThrowingExceptionInCreateSource].getName)
+        .load()
+        .writeStream
+        .format("console")
+        .start()
+      assert(ThrowingExceptionInCreateSource.createSourceLatch
+        .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+        "ThrowingExceptionInCreateSource.createSource wasn't called before timeout")
+      query.stop()
+      assert(query.exception.isEmpty)
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
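The last three exceptions in the `Seq` above model interrupts that reach the stream thread pre-wrapped. As a hedged illustration of how such wrapping can arise in practice, Guava's `LoadingCache` rethrows a loader failure from `getUnchecked` as `UncheckedExecutionException` with the original exception as the cause (`CacheBuilder`, `CacheLoader`, and `getUnchecked` are real Guava APIs; the deliberately failing loader is a toy of mine):

```scala
import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.UncheckedExecutionException

object CacheWrappingSketch {
  def main(args: Array[String]): Unit = {
    val cache = CacheBuilder.newBuilder().build(new CacheLoader[String, String] {
      // Simulate a load that is interrupted by a concurrent stop().
      override def load(key: String): String =
        throw new InterruptedException("interrupted while loading")
    })
    try {
      cache.getUnchecked("key")
    } catch {
      case e: UncheckedExecutionException =>
        // The original interrupt survives as the cause, which is exactly
        // what isInterruptedByStop digs out via getCause.
        assert(e.getCause.isInstanceOf[InterruptedException])
    }
  }
}
```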
@@ -814,3 +841,32 @@ class TestStateStoreProvider extends StateStoreProvider {

   override def getStore(version: Long): StateStore = null
 }
+
+/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */
+class ThrowingExceptionInCreateSource extends FakeSource {
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    ThrowingExceptionInCreateSource.createSourceLatch.countDown()
+    try {
+      Thread.sleep(30000)
+      throw new TimeoutException("sleep was not interrupted in 30 seconds")
+    } catch {
+      case _: InterruptedException =>
+        throw ThrowingExceptionInCreateSource.exception
+    }
+  }
+}
+
+object ThrowingExceptionInCreateSource {
+  /**
+   * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is
+   * called.
+   */
+  @volatile var createSourceLatch: CountDownLatch = null
+  @volatile var exception: Exception = null
+}