9 changes: 5 additions & 4 deletions python/pyspark/sql/connect/streaming/query.py
@@ -33,6 +33,7 @@
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
+    StreamingQueryProgress,
)
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
@@ -110,21 +111,21 @@ def status(self) -> Dict[str, Any]:
status.__doc__ = PySparkStreamingQuery.status.__doc__

@property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.recent_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
-        return [json.loads(p) for p in progress]
+        return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]

recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__

@property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.last_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
if len(progress) > 0:
-            return json.loads(progress[-1])
+            return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
else:
return None

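Illustrative usage, not part of the diff (assumes a running streaming query named query on a Spark Connect session): progress values are now typed objects, and dict-style access still works.

    last = query.lastProgress                # StreamingQueryProgress, was Dict[str, Any]
    if last is not None:
        assert last.numInputRows == last["numInputRows"]
    for p in query.recentProgress:           # List[StreamingQueryProgress]
        print(p.batchId, p.numInputRows)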
89 changes: 79 additions & 10 deletions python/pyspark/sql/streaming/listener.py
@@ -397,10 +397,13 @@ def errorClassOnException(self) -> Optional[str]:
return self._errorClassOnException


-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
"""
.. versionadded:: 3.4.0

+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict

Notes
-----
This API is evolving.
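Aside, not part of the diff: subclassing dict is what keeps pre-4.0 code working. A toy sketch of the pattern, with hypothetical class and field names:

    import json

    class Progress(dict):
        @property
        def batchId(self) -> int:
            return self["batchId"]

    p = Progress(batchId=3)
    assert p["batchId"] == 3                  # pre-4.0 dict-style access still works
    assert p.batchId == 3                     # new attribute-style access
    assert json.dumps(p) == '{"batchId": 3}'  # still serializes like a plain dict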
@@ -489,9 +492,11 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
sources=[SourceProgress.fromJson(s) for s in j["sources"]],
sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
+            processedRowsPerSecond=j["processedRowsPerSecond"]
+            if "processedRowsPerSecond" in j
+            else None,
observedMetrics={
k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows
for k, row_dict in j["observedMetrics"].items()
@@ -500,6 +505,19 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
else {},
)
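An equivalent, more compact spelling of the None fallbacks above, behavior-identical since dict.get returns None for a missing key (a suggestion, not what the diff uses):

    numInputRows=j.get("numInputRows"),
    inputRowsPerSecond=j.get("inputRowsPerSecond"),
    processedRowsPerSecond=j.get("processedRowsPerSecond"),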

+    def __getitem__(self, key: str) -> Any:
+        # Before Spark 4.0, StreamingQuery.lastProgress returned a dict that cast id and
+        # runId to string, but here they are UUIDs. To avoid a breaking change, cast them
+        # back to string when accessed with __getitem__.
+        if key == "id" or key == "runId":
[Review comment] @WweiL (Contributor, Author), Jun 8, 2024:
I'm not sure if this is really needed. But if we deleted this if, query.lastProgress["id"] would now return type uuid, where before it was a string.

[Reply] @WweiL (Contributor, Author):
Because there would be lots of breaking changes (e.g. the sources method now also returns the actual SourceProgress:

    def sources(self) -> List["SourceProgress"]:

), let me also make these subclasses of dict...
+            return str(getattr(self, key))
+        else:
+            return getattr(self, key)
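    # Illustrative behavior of the override above, not part of the diff (assumes a
    # running query q):
    #   q.lastProgress.id     -> uuid.UUID  (attribute access, new in 4.0)
    #   q.lastProgress["id"]  -> str        (dict access keeps the pre-4.0 behavior)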

+    def __setitem__(self, key: str, value: Any) -> None:
[Review comment] (Member):
Can we remove the fixes of __getitem__ and __setitem__, and instead do self.update(dict(id=id, runId=runId, ...)) in __init__?

[Reply] @WweiL (Contributor, Author):
ah sure, let me do that

+        internal_key = "_" + key
+        setattr(self, internal_key, value)
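A minimal sketch of the reviewer's suggestion above (hypothetical and simplified; the parameter names come from the reviewer's comment, everything else is illustrative): populate the dict once in __init__ so that item access needs no overrides.

    import uuid

    class ProgressSketch(dict):
        def __init__(self, id: uuid.UUID, runId: uuid.UUID) -> None:
            super().__init__()
            self._id = id
            self._runId = runId
            # Fill the dict view up front, casting UUIDs to strings for compatibility.
            self.update({"id": str(id), "runId": str(runId)})

        @property
        def id(self) -> uuid.UUID:
            return self._id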

@property
def id(self) -> uuid.UUID:
"""
@@ -603,21 +621,30 @@ def numInputRows(self) -> int:
"""
The aggregate (across all sources) number of records processed in a trigger.
"""
-        return self._numInputRows
+        if self._numInputRows is not None:
+            return self._numInputRows
+        else:
+            return sum(s.numInputRows for s in self.sources)

@property
def inputRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate of data arriving.
"""
-        return self._inputRowsPerSecond
+        if self._inputRowsPerSecond is not None:
+            return self._inputRowsPerSecond
+        else:
+            return sum(s.inputRowsPerSecond for s in self.sources)

@property
def processedRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate at which Spark is processing data.
"""
-        return self._processedRowsPerSecond
+        if self._processedRowsPerSecond is not None:
+            return self._processedRowsPerSecond
+        else:
+            return sum(s.processedRowsPerSecond for s in self.sources)
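    # Worked example of the fallback, not part of the diff (hypothetical numbers): when
    # the top-level aggregate is absent from the JSON, the property recomputes it from
    # the per-source values:
    #   progress.sources[0].numInputRows == 30
    #   progress.sources[1].numInputRows == 12
    #   progress.numInputRows            == 42  (summed because _numInputRows is None)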

@property
def json(self) -> str:
@@ -644,11 +671,17 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
return self.prettyJson

+    def __repr__(self) -> str:
+        return self.prettyJson


-class StateOperatorProgress:
+class StateOperatorProgress(dict):
"""
.. versionadded:: 3.4.0

+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict

Notes
-----
This API is evolving.
@@ -795,11 +828,24 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
return self.prettyJson

+    def __repr__(self) -> str:
+        return self.prettyJson

+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)

+    def __setitem__(self, key: str, value: Any) -> None:
+        internal_key = "_" + key
+        setattr(self, internal_key, value)

-class SourceProgress:
+class SourceProgress(dict):
"""
.. versionadded:: 3.4.0

+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict

Notes
-----
This API is evolving.
@@ -935,11 +981,24 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
return self.prettyJson

+    def __repr__(self) -> str:
+        return self.prettyJson

+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)

-class SinkProgress:
+    def __setitem__(self, key: str, value: Any) -> None:
+        internal_key = "_" + key
+        setattr(self, internal_key, value)


+class SinkProgress(dict):
"""
.. versionadded:: 3.4.0

+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict

Notes
-----
This API is evolving.
@@ -1021,6 +1080,16 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
return self.prettyJson

+    def __repr__(self) -> str:
+        return self.prettyJson

+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)

+    def __setitem__(self, key: str, value: Any) -> None:
+        internal_key = "_" + key
+        setattr(self, internal_key, value)


def _test() -> None:
import sys
13 changes: 8 additions & 5 deletions python/pyspark/sql/streaming/query.py
@@ -22,7 +22,10 @@
from pyspark.errors.exceptions.captured import (
StreamingQueryException as CapturedStreamingQueryException,
)
-from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.streaming.listener import (
+    StreamingQueryListener,
+    StreamingQueryProgress,
+)

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
@@ -251,7 +254,7 @@ def status(self) -> Dict[str, Any]:
return json.loads(self._jsq.status().json())

@property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
"""
Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
The number of progress updates retained for each stream is configured by Spark session
@@ -280,10 +283,10 @@ def recentProgress(self) -> List[Dict[str, Any]]:

>>> sq.stop()
"""
-        return [json.loads(p.json()) for p in self._jsq.recentProgress()]
+        return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]

@property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
"""
Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
None if there were no progress updates
@@ -311,7 +314,7 @@ def lastProgress(self) -> Optional[Dict[str, Any]]:
"""
lastProgress = self._jsq.lastProgress()
if lastProgress:
-            return json.loads(lastProgress.json())
+            return StreamingQueryProgress.fromJObject(lastProgress)
else:
return None

37 changes: 36 additions & 1 deletion python/pyspark/sql/tests/streaming/test_streaming.py
@@ -29,7 +29,7 @@

class StreamingTestsMixin:
def test_streaming_query_functions_basic(self):
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = (
df.writeStream.format("memory")
.queryName("test_streaming_query_functions_basic")
@@ -59,6 +59,41 @@ def test_streaming_query_functions_basic(self):
finally:
query.stop()

+    def test_streaming_progress(self):
+        """
+        Should be able to access fields using attributes in lastProgress / recentProgress
+        e.g. q.lastProgress.id
+        """
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        query = df.writeStream.format("noop").start()
+        try:
+            query.processAllAvailable()
+            recentProgress = query.recentProgress
+            lastProgress = query.lastProgress
+            self.assertEqual(lastProgress["name"], query.name)
+            self.assertEqual(lastProgress["id"], query.id)
+            # SPARK-48567 Use attribute to access fields in q.lastProgress
+            self.assertEqual(lastProgress.name, query.name)
+            self.assertEqual(str(lastProgress.id), query.id)
+            new_name = "myNewQuery"
+            lastProgress["name"] = new_name
+            self.assertEqual(lastProgress.name, new_name)
+            self.assertTrue(any(p == lastProgress for p in recentProgress))
+            self.assertTrue(lastProgress.numInputRows > 0)
+            # Also access source / sink progress with attributes
+            self.assertTrue(len(lastProgress.sources) > 0)
+            self.assertTrue(lastProgress.sources[0].numInputRows > 0)
+            self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0)
+            self.assertTrue(lastProgress.sink.numOutputRows > 0)

+        except Exception as e:
+            self.fail(
+                "Streaming query functions sanity check shouldn't throw any error. "
+                "Error message: " + str(e)
+            )
+        finally:
+            query.stop()

def test_streaming_query_name_edge_case(self):
# Query name should be None when not specified
q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
32 changes: 29 additions & 3 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -227,9 +227,9 @@ def onQueryTerminated(self, event):
"my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc")
)

-        q = observed_ds.writeStream.format("console").start()
+        q = observed_ds.writeStream.format("noop").start()

while q.lastProgress is None or q.lastProgress["batchId"] == 0:
while q.lastProgress is None or q.lastProgress.batchId == 0:
q.awaitTermination(0.5)

time.sleep(5)
@@ -241,6 +241,32 @@
q.stop()
self.spark.streams.removeListener(error_listener)

+    def test_streaming_progress(self):
+        try:
+            # Test a fancier query with stateful operation and observed metrics
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("update")
+                .trigger(processingTime="5 seconds")
+                .start()
+            )

+            while q.lastProgress is None or q.lastProgress.batchId == 0:
+                q.awaitTermination(0.5)

+            q.stop()

+            self.check_streaming_query_progress(q.lastProgress, True)
+            for p in q.recentProgress:
+                self.check_streaming_query_progress(p, True)

+        finally:
+            q.stop()


class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
def test_number_of_public_methods(self):
@@ -355,7 +381,7 @@ def verify(test_listener):
.start()
)
self.assertTrue(q.isActive)
-        time.sleep(10)
+        q.awaitTermination(10)
q.stop()

# Make sure all events are empty