WweiL committed Jun 8, 2024
commit c73ebfc9c7e0c43fb80f3a608fd59154e4c946f3
9 changes: 5 additions & 4 deletions python/pyspark/sql/connect/streaming/query.py
@@ -33,6 +33,7 @@
     QueryProgressEvent,
     QueryIdleEvent,
     QueryTerminatedEvent,
+    StreamingQueryProgress,
 )
 from pyspark.sql.streaming.query import (
     StreamingQuery as PySparkStreamingQuery,
@@ -110,21 +111,21 @@ def status(self) -> Dict[str, Any]:
     status.__doc__ = PySparkStreamingQuery.status.__doc__

     @property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
         cmd = pb2.StreamingQueryCommand()
         cmd.recent_progress = True
         progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
-        return [json.loads(p) for p in progress]
+        return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]

     recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__

     @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
         cmd = pb2.StreamingQueryCommand()
         cmd.last_progress = True
         progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
         if len(progress) > 0:
-            return json.loads(progress[-1])
+            return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
         else:
             return None

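With this change, the Spark Connect client hands back typed StreamingQueryProgress objects instead of plain dicts. The server still transports each progress update as a JSON string in recent_progress_json, so the client decodes twice: json.loads yields the raw dict, and StreamingQueryProgress.fromJson wraps it; lastProgress reuses the same list and takes its final entry. A minimal sketch of that selection logic, with plain dicts standing in for the fromJson call (the payloads below are trimmed and illustrative; real progress JSON carries many more fields):

    import json

    # Stand-in for recent_progress.recent_progress_json: one JSON string per update.
    recent_progress_json = ['{"batchId": 0}', '{"batchId": 1}', '{"batchId": 2}']

    # recentProgress decodes every entry.
    recent = [json.loads(p) for p in recent_progress_json]

    # lastProgress decodes only the newest entry, or yields None when the list is empty.
    last = json.loads(recent_progress_json[-1]) if recent_progress_json else None
    assert last == {"batchId": 2}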
28 changes: 21 additions & 7 deletions python/pyspark/sql/streaming/listener.py
@@ -394,9 +394,11 @@ def errorClassOnException(self) -> Optional[str]:
         return self._errorClassOnException


-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
     """
     .. versionadded:: 3.4.0
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict

     Notes
     -----
@@ -486,9 +488,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
+            processedRowsPerSecond=j["processedRowsPerSecond"] if "processedRowsPerSecond" in j else None,
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
@@ -497,6 +499,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
                 else {},
         )

+    def __getitem__(self, key):
+        return getattr(self, key)
+
     @property
     def id(self) -> uuid.UUID:
         """
@@ -600,21 +605,30 @@ def numInputRows(self) -> int:
"""
The aggregate (across all sources) number of records processed in a trigger.
"""
return self._numInputRows
if self._numInputRows is not None:
return self._numInputRows
else:
return sum(s.numInputRows for s in self.sources)

@property
def inputRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate of data arriving.
"""
return self._inputRowsPerSecond
if self._inputRowsPerSecond is not None:
return self._inputRowsPerSecond
else:
return sum(s.inputRowsPerSecond for s in self.sources)

@property
def processedRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate at which Spark is processing data.
"""
return self._processedRowsPerSecond
if self._processedRowsPerSecond is not None:
return self._processedRowsPerSecond
else:
return sum(s.processedRowsPerSecond for s in self.sources)

@property
def json(self) -> str:
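Two details above do the heavy lifting. The __getitem__ override routes dict-style lookups through getattr, so subscripting and attribute access always agree, and the three row-count properties now fall back to summing the per-source values when the top-level field is absent from the JSON. A self-contained sketch of the same pattern (the class and field names here are illustrative, not Spark's):

    class Progress(dict):
        # Toy version of the dict subclass + property-routing pattern.
        def __init__(self, sources, num_input_rows=None):
            self._sources = sources
            self._num_input_rows = num_input_rows

        def __getitem__(self, key):
            # Route subscript access through the property so derived
            # values stay consistent with attribute access.
            return getattr(self, key)

        @property
        def numInputRows(self):
            # Prefer the explicit value; otherwise aggregate across sources.
            if self._num_input_rows is not None:
                return self._num_input_rows
            return sum(s["numInputRows"] for s in self._sources)

    p = Progress(sources=[{"numInputRows": 3}, {"numInputRows": 4}])
    assert p.numInputRows == p["numInputRows"] == 7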
13 changes: 8 additions & 5 deletions python/pyspark/sql/streaming/query.py
@@ -22,7 +22,10 @@
 from pyspark.errors.exceptions.captured import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
-from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.streaming.listener import (
+    StreamingQueryListener,
+    StreamingQueryProgress,
+)

 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -251,7 +254,7 @@ def status(self) -> Dict[str, Any]:
         return json.loads(self._jsq.status().json())

     @property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
         """
         Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
         The number of progress updates retained for each stream is configured by Spark session
@@ -280,10 +283,10 @@ def recentProgress(self) -> List[Dict[str, Any]]:

         >>> sq.stop()
         """
-        return [json.loads(p.json()) for p in self._jsq.recentProgress()]
+        return [StreamingQueryProgress.fromJson(json.loads(p.json())) for p in self._jsq.recentProgress()]

     @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
         """
         Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
         None if there were no progress updates
@@ -311,7 +314,7 @@ def lastProgress(self) -> Optional[Dict[str, Any]]:
"""
lastProgress = self._jsq.lastProgress()
if lastProgress:
return json.loads(lastProgress.json())
return StreamingQueryProgress.fromJson(json.loads(lastProgress.json()))
else:
return None

Expand Down