SparkConnectFetchErrorDetailsHandler.scala
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.internal.SQLConf

/**
* Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(

ErrorUtils.throwableToFetchErrorDetailsResponse(
st = error,
serverStackTraceEnabled = sessionHolder.session.conf.get(
Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
Contributor:
Are we deprecating the Connect.CONNECT_SERVER_STACKTRACE_ENABLED config?

Contributor (Author):
It's still used, but it now only controls the display behavior rather than the stack trace generation.
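As a rough illustration of that split (a sketch only; the object name, the package placement, and the location of the display-side check are assumptions, not code from this PR):

```scala
package org.apache.spark.sql.connect.service

import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.internal.SQLConf

object StackTraceConfigSketch {
  // After this change the detailed error response always carries server stack traces.
  def buildResponse(error: Throwable, sessionHolder: SessionHolder) =
    ErrorUtils.throwableToFetchErrorDetailsResponse(
      st = error,
      serverStackTraceEnabled = true)

  // Hypothetical display-side gate: the configs now only decide whether the
  // server frames are rendered into the user-facing error message.
  def showServerStackTrace(sessionHolder: SessionHolder): Boolean =
    sessionHolder.session.conf.get(Connect.CONNECT_SERVER_STACKTRACE_ENABLED) ||
      sessionHolder.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)
}
```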

Contributor:
Should we also make Connect.CONNECT_SERVER_STACKTRACE_ENABLED work for the Scala client in this PR?

Contributor (Author):
This is a bit different: in contrast to Python, the server-side stack trace is always present on the Scala client, but the user can decide how to print it (a rough sketch follows this hunk).

SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
serverStackTraceEnabled = true)
}
.getOrElse(FetchErrorDetailsResponse.newBuilder().build())

ErrorUtils.scala
@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))

// Add the SQL State and Error Class to the response metadata of the ErrorInfoObject.
st match {
case e: SparkThrowable =>
val state = e.getSqlState
if (state != null && state.nonEmpty) {
errorInfo.putMetadata("sqlState", state)
}
val errorClass = e.getErrorClass
if (errorClass != null && errorClass.nonEmpty) {
errorInfo.putMetadata("errorClass", errorClass)
}
case _ =>
}

if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
FetchErrorDetailsHandlerSuite.scala
@@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelp
assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName)
assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName)
assert(!response.getErrors(1).hasCauseIdx)
if (serverStacktraceEnabled) {
assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
assert(
response.getErrors(1).getStackTraceCount ==
testError.getCause.getStackTrace.length)
} else {
assert(response.getErrors(0).getStackTraceCount == 0)
assert(response.getErrors(1).getStackTraceCount == 0)
}
assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
assert(
response.getErrors(1).getStackTraceCount ==
testError.getCause.getStackTrace.length)

} finally {
sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)
SparkConnectSessionHolderSuite.scala
@@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
accumulator = null)
}

test("python listener process: process terminates after listener is removed") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
// scalastyle:on assume

val sessionHolder = SessionHolder.forTesting(spark)
try {
SparkConnectService.start(spark.sparkContext)

val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)

val id1 = "listener_removeListener_test_1"
val id2 = "listener_removeListener_test_2"
val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)

sessionHolder.cacheListenerById(id1, listener1)
spark.streams.addListener(listener1)
sessionHolder.cacheListenerById(id2, listener2)
spark.streams.addListener(listener2)

val (runner1, runner2) = (listener1.runner, listener2.runner)

// assert both python processes are running
assert(!runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)

// remove listener1
spark.streams.removeListener(listener1)
sessionHolder.removeCachedListener(id1)
// assert listener1's python process is not running
eventually(timeout(30.seconds)) {
assert(runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)
}

// remove listener2
spark.streams.removeListener(listener2)
sessionHolder.removeCachedListener(id2)
eventually(timeout(30.seconds)) {
// assert listener2's python process is not running
assert(runner2.isWorkerStopped().get)
// all listeners are removed
assert(spark.streams.listListeners().isEmpty)
}
} finally {
SparkConnectService.stop()
}
}

test("python foreachBatch process: process terminates after query is stopped") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
@@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(spark.streams.listListeners().length == 1) // only process termination listener
} finally {
SparkConnectService.stop()
// Wait for things to calm down.
Thread.sleep(4.seconds.toMillis)
// remove process termination listener
spark.streams.listListeners().foreach(spark.streams.removeListener)
}
}

test("python listener process: process terminates after listener is removed") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
// scalastyle:on assume

val sessionHolder = SessionHolder.forTesting(spark)
try {
SparkConnectService.start(spark.sparkContext)

val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)

val id1 = "listener_removeListener_test_1"
val id2 = "listener_removeListener_test_2"
val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)

sessionHolder.cacheListenerById(id1, listener1)
spark.streams.addListener(listener1)
sessionHolder.cacheListenerById(id2, listener2)
spark.streams.addListener(listener2)

val (runner1, runner2) = (listener1.runner, listener2.runner)

// assert both python processes are running
assert(!runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)

// remove listener1
spark.streams.removeListener(listener1)
sessionHolder.removeCachedListener(id1)
// assert listener1's python process is not running
eventually(timeout(30.seconds)) {
assert(runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)
}

// remove listener2
spark.streams.removeListener(listener2)
sessionHolder.removeCachedListener(id2)
eventually(timeout(30.seconds)) {
// assert listener2's python process is not running
assert(runner2.isWorkerStopped().get)
// all listeners are removed
assert(spark.streams.listListeners().isEmpty)
}
} finally {
SparkConnectService.stop()
}
}
}