@@ -51,7 +51,11 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
val streamingQueryStartedEventCache
: ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()

-  def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined
+  val lock = new Object()
+
+  def isServerSideListenerRegistered: Boolean = lock.synchronized {
+    streamingQueryServerSideListener.isDefined
+  }

/**
* The initialization of the server side listener and related resources. This method is called
@@ -62,7 +66,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
* @param responseObserver
* the responseObserver created from the first long running executeThread.
*/
-  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized {
val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
sessionHolder.session.streams.addListener(serverListener)
streamingQueryServerSideListener = Some(serverListener)
@@ -76,7 +80,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
* the latch, so the long-running thread can proceed to send back the final ResultComplete
* response.
*/
-  def cleanUp(): Unit = {
+  def cleanUp(): Unit = lock.synchronized {
streamingQueryServerSideListener.foreach { listener =>
sessionHolder.session.streams.removeListener(listener)
}
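
The registration state above is guarded by a single shared lock because checking and
mutating it is a check-then-act sequence. A short illustrative sketch of the same
pattern in Python (an analogue for exposition, not the server code):

import threading

class ListenerHolder:
    """Illustrative analogue of ServerSideListenerHolder's locking."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._listener = None  # mirrors streamingQueryServerSideListener

    def is_registered(self) -> bool:
        with self._lock:
            return self._listener is not None

    def init(self, listener) -> None:
        # Without the shared lock, a concurrent clean_up() could run between
        # a caller's is_registered() check and this assignment.
        with self._lock:
            self._listener = listener

    def clean_up(self) -> None:
        with self._lock:
            self._listener = None
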
@@ -106,18 +110,18 @@ private[sql] class SparkConnectListenerBusListener(
// all related sources are cleaned up, and the long-running thread will proceed to send
// the final ResultComplete response.
  private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = {
-    val event = StreamingQueryListenerEvent
-      .newBuilder()
-      .setEventJson(eventJson)
-      .setEventType(eventType)
-      .build()
+    try {
+      val event = StreamingQueryListenerEvent
+        .newBuilder()
+        .setEventJson(eventJson)
+        .setEventType(eventType)
+        .build()

-    val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
-    val eventResult = respBuilder
-      .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
-      .build()
+      val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+      val eventResult = respBuilder
+        .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
+        .build()

-    try {
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
@@ -143,14 +147,24 @@ }
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryProgressEvent to client, id: ${event.progress.id}" +
s" runId: ${event.progress.runId}, batch: ${event.progress.batchId}.")
send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT)
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: ${event.runId}.")
send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
}

override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryIdleEvent to client, id: ${event.id} runId: ${event.runId}.")
send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT)
}
}
25 changes: 25 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -1067,6 +1067,28 @@ def execute_command(
else:
return (None, properties)

def execute_command_as_iterator(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
) -> Iterator[Dict[str, Any]]:
"""
Execute given command. Similar to execute_command, but the value is returned using yield.
"""
logger.info(f"Execute command as iterator for command {self._proto_to_string(command)}")
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
yield response
else:
raise PySparkValueError(
error_class="UNKNOWN_RESPONSE",
message_parameters={
"response": str(response),
},
)
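
A minimal usage sketch of the new iterator API, assuming an already-connected
SparkConnectClient instance `client` (hypothetical here); the command mirrors the
listener-bus registration that query.py sends below:

import pyspark.sql.connect.proto as pb2

def stream_listener_events(client) -> None:
    # Ask the server to register the listener-bus listener for this session.
    cmd = pb2.Command()
    cmd.streaming_query_listener_bus_command.add_listener_bus_listener = True
    # Each yielded value is a dict keyed by the response field name
    # (see handle_response below).
    for response in client.execute_command_as_iterator(cmd):
        if "streaming_query_listener_events_result" in response:
            print(response["streaming_query_listener_events_result"])
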

def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
"""
return if two plans have the same semantics.
@@ -1330,6 +1352,9 @@ def handle_response(
if b.HasField("streaming_query_manager_command_result"):
cmd_result = b.streaming_query_manager_command_result
yield {"streaming_query_manager_command_result": cmd_result}
if b.HasField("streaming_query_listener_events_result"):
event_result = b.streaming_query_listener_events_result
yield {"streaming_query_listener_events_result": event_result}
if b.HasField("get_resources_command_result"):
resources = {}
for key, resource in b.get_resources_command_result.resources.items():
196 changes: 174 additions & 22 deletions python/pyspark/sql/connect/streaming/query.py
@@ -20,23 +20,27 @@

import json
import sys
-import pickle
-from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+import warnings
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional, Union, Iterator
+from threading import Thread, Lock

from pyspark.errors import StreamingQueryException, PySparkValueError
import pyspark.sql.connect.proto as pb2
-from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect import proto
-from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.streaming import StreamingQueryListener
from pyspark.sql.streaming.listener import (
QueryStartedEvent,
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
)
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
StreamingQueryManager as PySparkStreamingQueryManager,
)
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)
-from pyspark.errors import PySparkPicklingError

if TYPE_CHECKING:
from pyspark.sql.connect.session import SparkSession
@@ -184,6 +188,7 @@ def _execute_streaming_query_cmd(
class StreamingQueryManager:
def __init__(self, session: "SparkSession") -> None:
self._session = session
self._sqlb = StreamingQueryListenerBus(self)

@property
def active(self) -> List[StreamingQuery]:
@@ -237,27 +242,13 @@ def resetTerminated(self) -> None:
resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__

def addListener(self, listener: StreamingQueryListener) -> None:
listener._init_listener_id()
-        cmd = pb2.StreamingQueryManagerCommand()
-        expr = proto.PythonUDF()
-        try:
-            expr.command = CloudPickleSerializer().dumps(listener)
-        except pickle.PicklingError:
-            raise PySparkPicklingError(
-                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
-                message_parameters={"name": "addListener"},
-            )
-        expr.python_ver = get_python_ver()
-        cmd.add_listener.python_listener_payload.CopyFrom(expr)
-        cmd.add_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
+        listener._set_spark_session(self._session)
+        self._sqlb.append(listener)

addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__

def removeListener(self, listener: StreamingQueryListener) -> None:
-        cmd = pb2.StreamingQueryManagerCommand()
-        cmd.remove_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
+        self._sqlb.remove(listener)

removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__

@@ -273,6 +264,167 @@ def _execute_streaming_query_manager_cmd(
)


class StreamingQueryListenerBus:
"""
A client side listener bus that is responsible for buffering client side listeners,
receive listener events and invoke correct listener call backs.
"""

def __init__(self, sqm: "StreamingQueryManager") -> None:
self._sqm = sqm
self._listener_bus: List[StreamingQueryListener] = []
self._execution_thread: Optional[Thread] = None
self._lock = Lock()

def append(self, listener: StreamingQueryListener) -> None:
"""
Append a listener to the local listener bus. When the added listener is
the first listener, request the server to create the server side listener
and start a thread to handle query events.
"""
with self._lock:
self._listener_bus.append(listener)

if len(self._listener_bus) == 1:
assert self._execution_thread is None
try:
result_iter = self._register_server_side_listener()
except Exception as e:
                    warnings.warn(
                        f"Failed to add the listener due to exception: {e}\n"
                        f"The listener was not added; please try adding it again."
                    )
self._listener_bus.remove(listener)
return
self._execution_thread = Thread(
target=self._query_event_handler, args=(result_iter,)
)
self._execution_thread.start()

def remove(self, listener: StreamingQueryListener) -> None:
"""
Remove the listener from the local listener bus.

When the listener is not presented in the listener bus, do nothing.

When the removed listener is the last listener, ask the server to remove
the server side listener.
As a result, the listener handling thread created before
will return after processing remaining listener events. This function blocks until
all events are processed.
"""
with self._lock:
if listener not in self._listener_bus:
return

if len(self._listener_bus) == 1:
cmd = pb2.StreamingQueryListenerBusCommand()
cmd.remove_listener_bus_listener = True
exec_cmd = pb2.Command()
exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
try:
self._sqm._session.client.execute_command(exec_cmd)
except Exception as e:
                    warnings.warn(
                        f"Failed to remove the listener due to exception: {e}\n"
                        f"The listener was not removed; please try removing it again."
                    )
return
if self._execution_thread is not None:
self._execution_thread.join()
self._execution_thread = None

self._listener_bus.remove(listener)

def _register_server_side_listener(self) -> Iterator[Dict[str, Any]]:
"""
Send add listener request to the server, after received confirmation from the server,
start a new thread to handle these events.
"""
cmd = pb2.StreamingQueryListenerBusCommand()
cmd.add_listener_bus_listener = True
exec_cmd = pb2.Command()
exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
result_iter = self._sqm._session.client.execute_command_as_iterator(exec_cmd)
        # Block until the server confirms the listener bus listener was added.
for result in result_iter:
response = cast(
pb2.StreamingQueryListenerEventsResult,
result["streaming_query_listener_events_result"],
)
if response.HasField("listener_bus_listener_added"):
break
return result_iter

def _query_event_handler(self, iter: Iterator[Dict[str, Any]]) -> None:
"""
Handler function passed to the new thread, if there is any error while receiving
listener events, it means the connection is unstable. In this case, remove all listeners
and tell the user to add back the listeners.
"""
try:
for result in iter:
response = cast(
pb2.StreamingQueryListenerEventsResult,
result["streaming_query_listener_events_result"],
)
for event in response.events:
deserialized_event = self.deserialize(event)
self.post_to_all(deserialized_event)

except Exception as e:
warnings.warn(
"StreamingQueryListenerBus Handler thread received exception, all client side "
f"listeners are removed and handler thread is terminated. The error is: {e}"
)
with self._lock:
self._execution_thread = None
self._listener_bus.clear()
return

@staticmethod
def deserialize(
event: pb2.StreamingQueryListenerEvent,
) -> Union["QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"]:
if event.event_type == proto.StreamingQueryEventType.QUERY_PROGRESS_EVENT:
return QueryProgressEvent.fromJson(json.loads(event.event_json))
elif event.event_type == proto.StreamingQueryEventType.QUERY_TERMINATED_EVENT:
return QueryTerminatedEvent.fromJson(json.loads(event.event_json))
elif event.event_type == proto.StreamingQueryEventType.QUERY_IDLE_EVENT:
return QueryIdleEvent.fromJson(json.loads(event.event_json))
else:
raise PySparkValueError(
error_class="UNKNOWN_VALUE_FOR",
message_parameters={"var": f"proto.StreamingQueryEventType: {event.event_type}"},
)

def post_to_all(
self,
event: Union[
"QueryStartedEvent", "QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"
],
) -> None:
"""
Post listener events to all active listeners, note that if one listener throws,
it should not affect other listeners.
"""
with self._lock:
for listener in self._listener_bus:
try:
if isinstance(event, QueryStartedEvent):
listener.onQueryStarted(event)
elif isinstance(event, QueryProgressEvent):
listener.onQueryProgress(event)
elif isinstance(event, QueryIdleEvent):
listener.onQueryIdle(event)
elif isinstance(event, QueryTerminatedEvent):
listener.onQueryTerminated(event)
else:
warnings.warn(f"Unknown StreamingQueryListener event: {event}")
except Exception as e:
warnings.warn(f"Listener {str(listener)} threw an exception\n{e}")


def _test() -> None:
import doctest
import os