tmp
WweiL committed Jun 7, 2024
commit 89c8b70d3118f47648032bf54af98c8492abd24e
python/pyspark/sql/streaming/listener.py: 29 changes (22 additions, 7 deletions)
@@ -394,9 +394,11 @@ def errorClassOnException(self) -> Optional[str]:
         return self._errorClassOnException
 
 
-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
     """
     .. versionadded:: 3.4.0
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
 
     Notes
     -----
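For readers skimming the diff: a minimal sketch of what subclassing dict buys here, assuming the constructor seeds the dict from the parsed JSON (the jdict keyword later in this commit suggests that; the Progress toy class below is a hypothetical stand-in, not pyspark's actual class):

# Hypothetical toy illustrating the pattern, not pyspark's implementation.
class Progress(dict):
    def __init__(self, jdict, numInputRows=None):
        super().__init__(jdict)  # seed dict contents from the raw JSON
        self._numInputRows = numInputRows

p = Progress({"numInputRows": 10}, numInputRows=10)
print(p["numInputRows"])  # 10: dict-style access now works alongside attributes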
@@ -473,6 +475,10 @@ def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
 
     @classmethod
     def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
+        num_input_rows = j.get("numInputRows", None)
+        input_rows_per_sec = j.get("inputRowsPerSecond", None)
+        processed_rows_per_sec = j.get("processedRowsPerSecond", None)
+
         return cls(
             jdict=j,
             id=uuid.UUID(j["id"]),
@@ -486,9 +492,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
+            processedRowsPerSecond=j["processedRowsPerSecond"] if "processedRowsPerSecond" in j else None,
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
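One small observation on the conditionals added above: the pattern j["k"] if "k" in j else None behaves the same as plain dict.get, which the unused locals in the previous hunk already use. A quick standalone check (plain Python, no pyspark needed):

j = {"numInputRows": 5}
assert (j["numInputRows"] if "numInputRows" in j else None) == j.get("numInputRows") == 5
assert (j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None) is None
assert j.get("inputRowsPerSecond") is None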
@@ -600,21 +606,30 @@ def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        return self._numInputRows
+        if self._numInputRows is not None:
+            return self._numInputRows
+        else:
+            return sum(s.numInputRows for s in self.sources)
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate of data arriving.
         """
-        return self._inputRowsPerSecond
+        if self._inputRowsPerSecond is not None:
+            return self._inputRowsPerSecond
+        else:
+            return sum(s.inputRowsPerSecond for s in self.sources)
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate at which Spark is processing data.
         """
-        return self._processedRowsPerSecond
+        if self._processedRowsPerSecond is not None:
+            return self._processedRowsPerSecond
+        else:
+            return sum(s.processedRowsPerSecond for s in self.sources)
 
     @property
     def json(self) -> str:
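To make the new fallback behavior concrete, here is a runnable sketch of the same pattern with stand-in classes (Source and Progress are hypothetical stand-ins for SourceProgress and StreamingQueryProgress; only the property logic mirrors the diff):

class Source:
    def __init__(self, numInputRows: int):
        self.numInputRows = numInputRows

class Progress:
    def __init__(self, numInputRows, sources):
        # _numInputRows is None when the aggregate was absent from the JSON
        self._numInputRows = numInputRows
        self.sources = sources

    @property
    def numInputRows(self) -> int:
        # Prefer the reported aggregate; otherwise sum the per-source values
        if self._numInputRows is not None:
            return self._numInputRows
        else:
            return sum(s.numInputRows for s in self.sources)

p = Progress(None, [Source(3), Source(7)])
print(p.numInputRows)  # 10: summed across sources when the aggregate is missing

The diff applies the same fallback to inputRowsPerSecond and processedRowsPerSecond.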