9 changes: 5 additions & 4 deletions python/pyspark/sql/connect/streaming/query.py
@@ -33,6 +33,7 @@
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
StreamingQueryProgress,
)
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
@@ -110,21 +111,21 @@ def status(self) -> Dict[str, Any]:
status.__doc__ = PySparkStreamingQuery.status.__doc__

@property
def recentProgress(self) -> List[Dict[str, Any]]:
def recentProgress(self) -> List[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.recent_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
return [json.loads(p) for p in progress]
return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]

recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__

@property
def lastProgress(self) -> Optional[Dict[str, Any]]:
def lastProgress(self) -> Optional[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.last_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
if len(progress) > 0:
return json.loads(progress[-1])
return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
else:
return None

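With this change, the Spark Connect client's recentProgress and lastProgress return StreamingQueryProgress objects instead of plain dicts, and because the class now subclasses dict, key-style access keeps working. A minimal usage sketch of the intended behavior (the rate source, noop sink, local session, and polling loop are illustrative assumptions, not part of this diff):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
q = spark.readStream.format("rate").load().writeStream.format("noop").start()

# Poll until at least one progress update is available (same pattern as in the listener tests below).
while q.lastProgress is None:
    q.awaitTermination(0.5)

p = q.lastProgress
print(p.batchId, p.numInputRows)        # new: attribute access
print(p["batchId"], p["numInputRows"])  # still works: dict-style key access
print(type(p.id), type(p["id"]))        # uuid.UUID vs. str (string kept for compatibility)
q.stop()
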
41 changes: 34 additions & 7 deletions python/pyspark/sql/streaming/listener.py
@@ -394,9 +394,11 @@ def errorClassOnException(self) -> Optional[str]:
return self._errorClassOnException


class StreamingQueryProgress:
class StreamingQueryProgress(dict):
"""
.. versionadded:: 3.4.0
.. versionchanged:: 4.0.0
Becomes a subclass of dict

Notes
-----
@@ -486,9 +488,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
sources=[SourceProgress.fromJson(s) for s in j["sources"]],
sink=SinkProgress.fromJson(j["sink"]),
numInputRows=j["numInputRows"],
inputRowsPerSecond=j["inputRowsPerSecond"],
processedRowsPerSecond=j["processedRowsPerSecond"],
numInputRows=j["numInputRows"] if "numInputRows" in j else None,
inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
processedRowsPerSecond=j["processedRowsPerSecond"] if "processedRowsPerSecond" in j else None,
observedMetrics={
k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows
for k, row_dict in j["observedMetrics"].items()
@@ -497,6 +499,19 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
else {},
)

def __getitem__(self, key):
# Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
# to string. To prevent breaking change, also cast them to string when accessed with
# __getitem__.
if key == "id" or key == "runId":
[Review comment from @WweiL (Contributor Author), Jun 8, 2024]
I'm not sure if this is really needed. But if we delete this if, query.lastProgress["id"] would now return type uuid.UUID; before, it was a string.

[Review comment from @WweiL (Contributor Author)]
Because there would be lots of breaking changes otherwise (e.g. the sources method now also returns the actual SourceProgress:

def sources(self) -> List["SourceProgress"]:

), let me also make these subclasses of dict...

return str(getattr(self, key))
else:
return getattr(self, key)

def __setitem__(self, key, value):
[Review comment from @WweiL (Contributor Author)]
This is out of fear that users may have set values on the returned dict before this change.

[Review comment from a reviewer (Member)]

Hmmm .. but the end users can't access this value if I am reading this correctly?

[Review comment from @WweiL (Contributor Author), Jun 10, 2024]
The fear is backward compatibility. This is possible in current master:

>>> q = spark.readStream.format("rate").load().writeStream.format("noop").start()
24/06/10 16:10:35 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /private/var/folders/9k/pbxb4_690wv4smwhwbzwmqkw0000gp/T/temporary-709975db-23ed-4838-b9ae-93a7ffe59183. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
24/06/10 16:10:35 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
>>> p = q.lastProgress
>>> p
{'id': '44510846-29f8-4218-95cf-616efecadb05', 'runId': 'afcac0a7-424b-428b-948e-2c0fc21a43a2', 'name': None, 'timestamp': '2024-06-10T23:10:38.257Z', 'batchId': 2, 'batchDuration': 215, 'numInputRows': 1, 'inputRowsPerSecond': 76.92307692307692, 'processedRowsPerSecond': 4.651162790697675, 'durationMs': {'addBatch': 30, 'commitOffsets': 82, 'getBatch': 0, 'latestOffset': 0, 'queryPlanning': 4, 'triggerExecution': 215, 'walCommit': 98}, 'stateOperators': [], 'sources': [{'description': 'RateStreamV2[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=default', 'startOffset': 1, 'endOffset': 2, 'latestOffset': 2, 'numInputRows': 1, 'inputRowsPerSecond': 76.92307692307692, 'processedRowsPerSecond': 4.651162790697675}], 'sink': {'description': 'org.apache.spark.sql.execution.datasources.noop.NoopTable$@67a2b2a4', 'numOutputRows': 1}}
>>> p["id"]
'44510846-29f8-4218-95cf-616efecadb05'
>>> p["id"] = "aaaaaaa"
>>> p["id"]
'aaaaaaa'

This is not possible in Scala of course, but I'm not sure if we should keep this Python-specific behavior....

internal_key = "_" + key
setattr(self, internal_key, value)

@property
def id(self) -> uuid.UUID:
"""
@@ -600,21 +615,30 @@ def numInputRows(self) -> int:
"""
The aggregate (across all sources) number of records processed in a trigger.
"""
return self._numInputRows
if self._numInputRows is not None:
return self._numInputRows
else:
return sum(s.numInputRows for s in self.sources)

@property
def inputRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate of data arriving.
"""
return self._inputRowsPerSecond
if self._inputRowsPerSecond is not None:
return self._inputRowsPerSecond
else:
return sum(s.inputRowsPerSecond for s in self.sources)

@property
def processedRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate at which Spark is processing data.
"""
return self._processedRowsPerSecond
if self._processedRowsPerSecond is not None:
return self._processedRowsPerSecond
else:
return sum(s.processedRowsPerSecond for s in self.sources)

@property
def json(self) -> str:
@@ -641,6 +665,9 @@ def prettyJson(self) -> str:
def __str__(self) -> str:
return self.prettyJson

def __repr__(self) -> str:
return self.prettyJson


class StateOperatorProgress:
"""
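The review thread above centers on the dict-subclass compatibility shim. A stripped-down, standalone illustration of that pattern (a toy class written for this explanation, not the real StreamingQueryProgress) shows why p["id"] stays a string while p.id keeps its uuid.UUID type, and why old code that assigned into the returned dict keeps working:

import uuid

class ToyProgress(dict):
    """Toy stand-in for the pattern used by StreamingQueryProgress in this PR."""

    def __init__(self, id: uuid.UUID, batchId: int):
        self._id = id
        self._batchId = batchId

    @property
    def id(self) -> uuid.UUID:
        return self._id

    @property
    def batchId(self) -> int:
        return self._batchId

    def __getitem__(self, key):
        # Pre-4.0 dict results exposed "id"/"runId" as strings, so keep that here.
        if key in ("id", "runId"):
            return str(getattr(self, key))
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Old code could mutate the returned dict; route that to the backing attribute.
        setattr(self, "_" + key, value)


p = ToyProgress(uuid.uuid4(), 3)
assert isinstance(p.id, uuid.UUID)   # attribute access keeps the real type
assert isinstance(p["id"], str)      # key access stays a string, as before 4.0
p["batchId"] = 4                     # assignment still works and updates the attribute
assert p.batchId == 4

The actual class in listener.py applies the same idea generically across all of its fields.
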
13 changes: 8 additions & 5 deletions python/pyspark/sql/streaming/query.py
@@ -22,7 +22,10 @@
from pyspark.errors.exceptions.captured import (
StreamingQueryException as CapturedStreamingQueryException,
)
from pyspark.sql.streaming.listener import StreamingQueryListener
from pyspark.sql.streaming.listener import (
StreamingQueryListener,
StreamingQueryProgress,
)

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
@@ -251,7 +254,7 @@ def status(self) -> Dict[str, Any]:
return json.loads(self._jsq.status().json())

@property
def recentProgress(self) -> List[Dict[str, Any]]:
def recentProgress(self) -> List[StreamingQueryProgress]:
"""
Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
The number of progress updates retained for each stream is configured by Spark session
@@ -280,10 +283,10 @@ def recentProgress(self) -> List[Dict[str, Any]]:

>>> sq.stop()
"""
return [json.loads(p.json()) for p in self._jsq.recentProgress()]
return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]

@property
def lastProgress(self) -> Optional[Dict[str, Any]]:
def lastProgress(self) -> Optional[StreamingQueryProgress]:
"""
Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
None if there were no progress updates
@@ -311,7 +314,7 @@ def lastProgress(self) -> Optional[Dict[str, Any]]:
"""
lastProgress = self._jsq.lastProgress()
if lastProgress:
return json.loads(lastProgress.json())
return StreamingQueryProgress.fromJObject(lastProgress)
else:
return None

28 changes: 27 additions & 1 deletion python/pyspark/sql/tests/streaming/test_streaming.py
@@ -28,7 +28,7 @@

class StreamingTestsMixin:
def test_streaming_query_functions_basic(self):
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = (
df.writeStream.format("memory")
.queryName("test_streaming_query_functions_basic")
@@ -46,6 +46,7 @@ def test_streaming_query_functions_basic(self):
lastProgress = query.lastProgress
self.assertEqual(lastProgress["name"], query.name)
self.assertEqual(lastProgress["id"], query.id)
# SPARK-48567 Use attribute to access progress
self.assertTrue(any(p == lastProgress for p in recentProgress))
query.explain()

@@ -58,6 +59,31 @@ def test_streaming_query_functions_basic(self):
finally:
query.stop()

def test_streaming_progress(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = df.writeStream.format("noop").start()
try:
query.processAllAvailable()
recentProgress = query.recentProgress
lastProgress = query.lastProgress
self.assertEqual(lastProgress["name"], query.name)
self.assertEqual(lastProgress["id"], query.id)
# SPARK-48567 Use attribute to access fields in q.lastProgress
self.assertEqual(lastProgress.name, query.name)
self.assertEqual(str(lastProgress.id), query.id)
new_name = "myNewQuery"
lastProgress["name"] = new_name
self.assertEqual(lastProgress.name, new_name)
self.assertTrue(any(p == lastProgress for p in recentProgress))

except Exception as e:
self.fail(
"Streaming query functions sanity check shouldn't throw any error. "
"Error message: " + str(e)
)
finally:
query.stop()

def test_stream_trigger(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")

34 changes: 30 additions & 4 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -121,7 +121,7 @@ def check_streaming_query_progress(self, progress, is_stateful):

self.assertTrue(isinstance(progress.sink, SinkProgress))
self.check_sink_progress(progress.sink)
self.assertTrue(isinstance(progress.observedMetrics, dict))
self.assertTrue(isinstance(progress.observedMetrics, Row))

def check_state_operator_progress(self, progress):
"""Check StateOperatorProgress"""
@@ -227,9 +227,9 @@ def onQueryTerminated(self, event):
"my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc")
)

q = observed_ds.writeStream.format("console").start()
q = observed_ds.writeStream.format("noop").start()

while q.lastProgress is None or q.lastProgress["batchId"] == 0:
while q.lastProgress is None or q.lastProgress.batchId == 0:
q.awaitTermination(0.5)

time.sleep(5)
@@ -241,6 +241,32 @@ def onQueryTerminated(self, event):
q.stop()
self.spark.streams.removeListener(error_listener)

def test_streaming_progress(self):
try:
# Test a fancier query with stateful operation and observed metrics
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
df_stateful = df_observe.groupBy().count() # make query stateful
q = (
df_stateful.writeStream.format("noop")
.queryName("test")
.outputMode("update")
.trigger(processingTime="5 seconds")
.start()
)

while q.lastProgress is None or q.lastProgress.batchId == 0:
q.awaitTermination(0.5)

q.stop()

self.check_streaming_query_progress(q.lastProgress, True)
for p in q.recentProgress:
self.check_streaming_query_progress(p, True)

finally:
q.stop()


class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
def test_number_of_public_methods(self):
@@ -355,7 +381,7 @@ def verify(test_listener):
.start()
)
self.assertTrue(q.isActive)
time.sleep(10)
q.awaitTermination(10)
q.stop()

# Make sure all events are empty