fix and review
grundprinzip committed Nov 6, 2023
commit b2ee6acf95daa4082b303d2bcdac0b5b643d1db2
@@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
accumulator = null)
}

test("python listener process: process terminates after listener is removed") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
// scalastyle:on assume

val sessionHolder = SessionHolder.forTesting(spark)
try {
SparkConnectService.start(spark.sparkContext)

val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)

val id1 = "listener_removeListener_test_1"
val id2 = "listener_removeListener_test_2"
val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)

sessionHolder.cacheListenerById(id1, listener1)
spark.streams.addListener(listener1)
sessionHolder.cacheListenerById(id2, listener2)
spark.streams.addListener(listener2)

val (runner1, runner2) = (listener1.runner, listener2.runner)

// assert both python processes are running
assert(!runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)

// remove listener1
spark.streams.removeListener(listener1)
sessionHolder.removeCachedListener(id1)
// assert listener1's python process is not running
eventually(timeout(30.seconds)) {
assert(runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)
}

// remove listener2
spark.streams.removeListener(listener2)
sessionHolder.removeCachedListener(id2)
eventually(timeout(30.seconds)) {
// assert listener2's python process is not running
assert(runner2.isWorkerStopped().get)
// all listeners are removed
assert(spark.streams.listListeners().isEmpty)
}
} finally {
SparkConnectService.stop()
}
}

test("python foreachBatch process: process terminates after query is stopped") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
@@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(spark.streams.listListeners().length == 1) // only process termination listener
} finally {
SparkConnectService.stop()
// Wait for things to calm down.
Thread.sleep(4.seconds.toMillis)
// remove process termination listener
spark.streams.listListeners().foreach(spark.streams.removeListener)
}
}

test("python listener process: process terminates after listener is removed") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
// scalastyle:on assume

val sessionHolder = SessionHolder.forTesting(spark)
try {
SparkConnectService.start(spark.sparkContext)

val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)

val id1 = "listener_removeListener_test_1"
val id2 = "listener_removeListener_test_2"
val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)

sessionHolder.cacheListenerById(id1, listener1)
spark.streams.addListener(listener1)
sessionHolder.cacheListenerById(id2, listener2)
spark.streams.addListener(listener2)

val (runner1, runner2) = (listener1.runner, listener2.runner)

// assert both python processes are running
assert(!runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)

// remove listener1
spark.streams.removeListener(listener1)
sessionHolder.removeCachedListener(id1)
// assert listener1's python process is not running
eventually(timeout(30.seconds)) {
assert(runner1.isWorkerStopped().get)
assert(!runner2.isWorkerStopped().get)
}

// remove listener2
spark.streams.removeListener(listener2)
sessionHolder.removeCachedListener(id2)
eventually(timeout(30.seconds)) {
// assert listener2's python process is not running
assert(runner2.isWorkerStopped().get)
// all listeners are removed
assert(spark.streams.listListeners().isEmpty)
}
} finally {
SparkConnectService.stop()
}
}
}
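Note: a minimal client-side sketch (not part of this commit) of the pattern the relocated test exercises — a Python StreamingQueryListener added and then removed through a Spark Connect session. The listener class and the `spark` session variable are assumptions for illustration; on the server, adding the listener starts a dedicated Python worker, and removing it is what the test expects to terminate that worker.

# Hedged usage sketch, not part of this commit. Assumes `spark` is an active
# Spark Connect SparkSession and that a streaming query is (or will be) running.
from pyspark.sql.streaming import StreamingQueryListener

class LoggingListener(StreamingQueryListener):  # hypothetical listener for illustration
    def onQueryStarted(self, event):
        print("started:", event.id)

    def onQueryProgress(self, event):
        print("progress batch:", event.progress.batchId)

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print("terminated:", event.id)

listener = LoggingListener()
spark.streams.addListener(listener)     # server caches the listener and spawns a Python worker for it
# ... streaming queries run; events are forwarded to the worker ...
spark.streams.removeListener(listener)  # after removal, the backing worker is expected to stop
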
70 changes: 35 additions & 35 deletions python/pyspark/errors/exceptions/connect.py
@@ -16,7 +16,7 @@
#
import pyspark.sql.connect.proto as pb2
import json
from typing import Dict, List, Optional, TYPE_CHECKING, overload
from typing import Dict, List, Optional, TYPE_CHECKING

from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
@@ -49,7 +49,7 @@ def convert_exception(
info: "ErrorInfo",
truncated_message: str,
resp: Optional[pb2.FetchErrorDetailsResponse],
display_stacktrace: bool = False,
display_server_stacktrace: bool = False,
) -> SparkConnectException:
classes = []
sql_state = None
@@ -70,105 +70,105 @@
else:
message = truncated_message
stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None
display_stacktrace = display_stacktrace if stacktrace is not None else False
display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
# Order matters. ParseException inherits AnalysisException.
elif "org.apache.spark.sql.AnalysisException" in classes:
return AnalysisException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
return StreamingQueryException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
return QueryExecutionException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
# Order matters. NumberFormatException inherits IllegalArgumentException.
elif "java.lang.NumberFormatException" in classes:
return NumberFormatException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "java.lang.IllegalArgumentException" in classes:
return IllegalArgumentException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "java.lang.ArithmeticException" in classes:
return ArithmeticException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "java.lang.UnsupportedOperationException" in classes:
return UnsupportedOperationException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
return ArrayIndexOutOfBoundsException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "java.time.DateTimeException" in classes:
return DateTimeException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "org.apache.spark.SparkRuntimeException" in classes:
return SparkRuntimeException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "org.apache.spark.SparkUpgradeException" in classes:
return SparkUpgradeException(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
elif "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
@@ -181,17 +181,17 @@ def convert_exception(
message,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)
else:
return SparkConnectGrpcException(
message,
reason=info.reason,
error_class=error_class,
sql_state=sql_state,
stacktrace=stacktrace,
display_stacktrace=display_stacktrace,
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
)


@@ -234,8 +234,8 @@ def __init__(
message_parameters: Optional[Dict[str, str]] = None,
reason: Optional[str] = None,
sql_state: Optional[str] = None,
stacktrace: Optional[str] = None,
display_stacktrace: bool = False,
server_stacktrace: Optional[str] = None,
display_server_stacktrace: bool = False,
) -> None:
self.message = message # type: ignore[assignment]
if reason is not None:
@@ -258,8 +258,8 @@ def __init__(
)
self.error_class = error_class
self._sql_state: Optional[str] = sql_state
self._stacktrace: Optional[str] = stacktrace
self._display_stacktrace: bool = display_stacktrace
self._stacktrace: Optional[str] = server_stacktrace
self._display_stacktrace: bool = display_server_stacktrace

def getSqlState(self) -> None:
if self._sql_state is not None:
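Note: a minimal, self-contained sketch (not pyspark's actual class) of how the renamed keyword arguments are meant to behave — the server-side trace travels in server_stacktrace and is only appended to the rendered message when display_server_stacktrace is set. The class name and message format below are stand-ins; the real behavior lives in SparkConnectGrpcException and its subclasses in this file.

# Standalone sketch for illustration only.
from typing import Optional

class SketchedConnectError(Exception):  # hypothetical stand-in class
    def __init__(
        self,
        message: str,
        server_stacktrace: Optional[str] = None,
        display_server_stacktrace: bool = False,
    ) -> None:
        self._stacktrace = server_stacktrace
        self._display_stacktrace = display_server_stacktrace
        text = message
        if self._display_stacktrace and self._stacktrace is not None:
            # Append the server-side trace only when the caller opted in.
            text += "\n\nJVM stacktrace:\n" + self._stacktrace
        super().__init__(text)

# Only opted-in clients see the server-side trace in the rendered message.
err = SketchedConnectError(
    "[DIVIDE_BY_ZERO] Division by zero.",
    server_stacktrace="at org.apache.spark.sql... (truncated)",
    display_server_stacktrace=True,
)
print(err)
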
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -1564,7 +1564,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
except grpc.RpcError:
return None

def _display_stack_trace(self) -> bool:
def _display_server_stack_trace(self) -> bool:
from pyspark.sql.connect.conf import RuntimeConf

conf = RuntimeConf(self)
@@ -1605,7 +1605,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
info,
status.message,
self._fetch_enriched_error(info),
self._display_stack_trace(),
self._display_server_stack_trace(),
) from None

raise SparkConnectGrpcException(status.message) from None
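Note: a hedged sketch of what a _display_server_stack_trace-style check could look like on the client. Only the RuntimeConf(self) usage is visible in this diff; the conf key names below are assumptions for illustration, not confirmed by this change.

# Illustrative sketch only. The conf keys are assumed names, not taken from this diff.
def display_server_stack_trace(client) -> bool:
    from pyspark.sql.connect.conf import RuntimeConf

    conf = RuntimeConf(client)
    try:
        # Assumed key: a Connect-specific switch for showing server-side traces.
        if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
            return True
        # Assumed fallback: the JVM stacktrace flag used by classic PySpark.
        return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
    except Exception:
        # If the conf cannot be fetched (e.g. session already closed), hide the trace.
        return False
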