move asDict methods to test suite
WweiL committed Aug 21, 2023
commit fa8be5cb7a4105c0668de389cdd622fcfbfbf28b
67 changes: 0 additions & 67 deletions python/pyspark/sql/streaming/listener.py
@@ -22,15 +22,6 @@
from py4j.java_gateway import JavaObject

from pyspark.sql import Row
from pyspark.sql.types import (
ArrayType,
StructType,
StructField,
StringType,
IntegerType,
FloatType,
MapType,
)
from pyspark import cloudpickle

__all__ = ["StreamingQueryListener"]
@@ -206,15 +197,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent":
timestamp=j["timestamp"],
)

def asDict(self) -> Dict[str, Any]:
def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
else:
return obj

return {k[1:]: conv(v) for k, v in self.__dict__.items()}

@property
def id(self) -> uuid.UUID:
"""
@@ -275,9 +257,6 @@ def progress(self) -> "StreamingQueryProgress":
"""
return self._progress

def asDict(self) -> Dict[str, Any]:
return {"progress": self.progress.asDict()}


class QueryIdleEvent:
"""
@@ -307,15 +286,6 @@ def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent":
def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent":
return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"])

def asDict(self) -> Dict[str, Any]:
def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
else:
return obj

return {k[1:]: conv(v) for k, v in self.__dict__.items()}

@property
def id(self) -> uuid.UUID:
"""
@@ -383,15 +353,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent":
errorClassOnException=j["errorClassOnException"],
)

def asDict(self) -> Dict[str, Any]:
def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
else:
return obj

return {k[1:]: conv(v) for k, v in self.__dict__.items()}

@property
def id(self) -> uuid.UUID:
"""
@@ -535,25 +496,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
else {},
)

def asDict(self) -> Dict[str, Any]:
def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)):
return obj.asDict()
elif isinstance(obj, Row):
return json.dumps(obj.asDict()) # Assume no nested row in observed metrics
elif isinstance(obj, list):
return [conv(o) for o in obj]
elif isinstance(obj, dict):
return dict((k, conv(v)) for k, v in obj.items())
else:
return obj

return {
k[1:]: conv(v) for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]
}

@property
def id(self) -> uuid.UUID:
"""
@@ -776,9 +718,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress":
customMetrics=dict(j["customMetrics"]) if "customMetrics" in j else {},
)

def asDict(self) -> Dict[str, Any]:
return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]}

@property
def operatorName(self) -> str:
return self._operatorName
@@ -914,9 +853,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress":
metrics=dict(j["metrics"]) if "metrics" in j else {},
)

def asDict(self) -> Dict[str, Any]:
return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]}

@property
def description(self) -> str:
"""
@@ -1028,9 +964,6 @@ def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress":
metrics=dict(jprogress.metrics()),
)

def asDict(self) -> Dict[str, Any]:
return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]}

@classmethod
def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress":
return cls(
58 changes: 53 additions & 5 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -17,6 +17,9 @@

import unittest
import time
import uuid
import json
from typing import Any, Dict, Union

from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
from pyspark.sql.streaming.listener import (
@@ -25,6 +28,10 @@
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
StateOperatorProgress,
StreamingQueryProgress,
SourceProgress,
SinkProgress,
)
from pyspark.sql.types import (
ArrayType,
@@ -35,10 +42,51 @@
FloatType,
MapType,
)
from pyspark.sql import Row
from pyspark.sql.functions import count, lit
from pyspark.testing.connectutils import ReusedConnectTestCase


def listener_event_as_dict(
e: Union[QueryStartedEvent, QueryProgressEvent, QueryIdleEvent, QueryTerminatedEvent]
) -> Dict[str, Any]:
if isinstance(e, QueryProgressEvent):
return {"progress": streaming_query_progress_as_dict(e.progress)}
else:

def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
else:
return obj

return {k[1:]: conv(v) for k, v in e.__dict__.items()}


def streaming_query_progress_as_dict(e: StreamingQueryProgress) -> Dict[str, Any]:
[Review comment — Member]
Simpler way might be pyspark.cloudpickle.dumps(event): save that as a table, load it back, unpickle it via pyspark.cloudpickle.loads(binary), and compare them. (See the sketch at the end of this diff.)

[Reply — Contributor, author]
Ah thanks! Never thought of that

def conv(obj: Any) -> Any:
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)):
return other_progress_as_dict(obj)
elif isinstance(obj, Row):
return json.dumps(obj.asDict()) # Assume no nested row in observed metrics
elif isinstance(obj, list):
return [conv(o) for o in obj]
elif isinstance(obj, dict):
return dict((k, conv(v)) for k, v in obj.items())
else:
return obj

return {k[1:]: conv(v) for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]}


def other_progress_as_dict(
e: Union[StateOperatorProgress, SourceProgress, SinkProgress]
) -> Dict[str, Any]:
return {k[1:]: v for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]}


def get_start_event_schema():
return StructType(
[
@@ -147,14 +195,14 @@ def get_progress_event_schema():
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
df = self.spark.createDataFrame(
data=[(event.asDict())],
data=[listener_event_as_dict(event)],
schema=get_start_event_schema(),
)
df.write.saveAsTable("listener_start_events")
df.write.mode("append").saveAsTable("listener_start_events")

def onQueryProgress(self, event):
df = self.spark.createDataFrame(
data=[event.asDict()],
data=[listener_event_as_dict(event)],
schema=get_progress_event_schema(),
)
df.write.mode("append").saveAsTable("listener_progress_events")
@@ -164,10 +212,10 @@ def onQueryIdle(self, event):

def onQueryTerminated(self, event):
df = self.spark.createDataFrame(
data=[event.asDict()],
data=[listener_event_as_dict(event)],
schema=get_terminated_event_schema(),
)
df.write.saveAsTable("listener_terminated_events")
df.write.mode("append").saveAsTable("listener_terminated_events")


class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
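
The round-trip the reviewer suggests above could look roughly like this — a minimal sketch, not part of this commit; the listener class name, the "_bin" table names, and the final comparison step are illustrative assumptions:

from pyspark import cloudpickle
from pyspark.sql.streaming.listener import StreamingQueryListener


class PickleRoundTripListener(StreamingQueryListener):
    """Persists each listener event as a cloudpickled blob instead of a typed row."""

    def onQueryStarted(self, event):
        self._save(event, "listener_start_events_bin")

    def onQueryProgress(self, event):
        self._save(event, "listener_progress_events_bin")

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        self._save(event, "listener_terminated_events_bin")

    def _save(self, event, table):
        # Serialize the whole event object to bytes; no per-event schema needed.
        payload = cloudpickle.dumps(event)
        df = self.spark.createDataFrame([(payload,)], "event binary")
        df.write.mode("append").saveAsTable(table)

The test body would then read each blob back and compare the restored event against the one observed on the driver, e.g. restored = cloudpickle.loads(spark.read.table("listener_start_events_bin").collect()[0].event), followed by field-level assertions such as restored.id == event.id.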