Commit 2d44848

WweiL authored and HyukjinKwon committed
[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
### What changes were proposed in this pull request?

Add several new test cases for streaming foreachBatch and streaming query listener events to test various scenarios.

### Why are the changes needed?

More tests are better.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Test-only change.

Closes #42521 from WweiL/SPARK-44435-tests-foreachBatch-listener.

Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent f7a0b3a commit 2d44848

File tree

5 files changed (+144, -27 lines)

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala

Lines changed: 0 additions & 1 deletion
@@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging {
 
   val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
-    // TODO(SPARK-44460): Support Auth credentials
     // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
     // This is because MicroBatch execution clones the session during start.
     // The session attached to the foreachBatch dataframe is different from the one the one

python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     func = worker.read_command(pickle_ser, infile)
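
For context, both Python workers attach to an existing server-side session rather than creating a new one: they open a Spark Connect client session against the server and then rebind that client's session id. A minimal standalone sketch of the same re-attachment pattern; the connect_url and session_id values here are illustrative placeholders, not values from this commit:

    from pyspark.sql import SparkSession

    connect_url = "sc://localhost:15002"  # illustrative server address
    session_id = "00000000-0000-0000-0000-000000000000"  # illustrative existing session id

    # Open a client session against the Connect server...
    spark = SparkSession.builder.remote(connect_url).getOrCreate()
    # ...then point its client at the pre-existing server-side session,
    # exactly as foreachBatch_worker.py does above.
    spark._client._session_id = session_id  # type: ignore[attr-defined]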

python/pyspark/sql/connect/streaming/worker/listener_worker.py

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     listener = worker.read_command(pickle_ser, infile)
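
The command this worker reads back is an ordinary StreamingQueryListener subclass shipped from the client. As a reminder of the shape of that API, a sketch of a listener with the four callbacks the parity test below overrides (the print bodies are illustrative; registration assumes an active session named spark):

    from pyspark.sql.streaming.listener import StreamingQueryListener

    class MyListener(StreamingQueryListener):
        def onQueryStarted(self, event):
            print("query started:", event.id)

        def onQueryProgress(self, event):
            pass  # fires after each completed micro-batch

        def onQueryIdle(self, event):
            pass

        def onQueryTerminated(self, event):
            print("query terminated:", event.id)

    spark.streams.addListener(MyListener())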

python/pyspark/sql/tests/connect/streaming/test_parity_listener.py

Lines changed: 34 additions & 23 deletions
@@ -18,39 +18,31 @@
 import unittest
 import time
 
+import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-def get_start_event_schema():
-    return StructType(
-        [
-            StructField("id", StringType(), True),
-            StructField("runId", StringType(), True),
-            StructField("name", StringType(), True),
-            StructField("timestamp", StringType(), True),
-        ]
-    )
-
-
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        df = self.spark.createDataFrame(
-            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
-            schema=get_start_event_schema(),
-        )
-        df.write.saveAsTable("listener_start_events")
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events")
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ def test_listener_events(self):
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
 
             self.assertTrue(q.isActive)
             time.sleep(10)
+            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is ran
             q.stop()
+            self.assertFalse(q.isActive)
+
+            start_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_start_events").collect()[0][0]
+            )
+
+            progress_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_progress_events").collect()[0][0]
+            )
 
-            start_event = QueryStartedEvent.fromJson(
-                self.spark.read.table("listener_start_events").collect()[0].asDict()
+            terminated_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_terminated_events").collect()[0][0]
             )
 
             self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
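
The rewritten test no longer flattens the started event into a hand-maintained schema; instead every event is round-tripped whole through pyspark.cloudpickle, which can serialize arbitrary Python objects to bytes. A minimal sketch of that round-trip in isolation; FakeEvent is a stand-in class, not a real pyspark event type:

    import pyspark.cloudpickle

    class FakeEvent:  # stand-in for a listener event object
        def __init__(self, id, name):
            self.id = id
            self.name = name

    payload = pyspark.cloudpickle.dumps(FakeEvent("abc", "test"))  # -> bytes
    restored = pyspark.cloudpickle.loads(payload)
    assert restored.name == "test"

In the test, those bytes land in a single column inferred as binary by createDataFrame(data=[(e,)]), which is why collect()[0][0] is all that is needed to fetch an event back out.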

python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py

Lines changed: 110 additions & 1 deletion
@@ -16,10 +16,14 @@
 #
 
 import time
-
+from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
+def my_test_function_1():
+    return 1
+
+
 class StreamingTestsForeachBatchMixin:
     def test_streaming_foreachBatch(self):
         q = None
@@ -88,6 +92,111 @@ def func(batch_df, _):
             q.stop()
         self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
+    def test_streaming_foreachBatch_spark_session(self):
+        table_name = "testTable_foreachBatch"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame([("structured",), ("streaming",)])
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = (
+            self.spark.read.format("text")
+            .load(path="python/test_support/sql/streaming/")
+            .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_path_access(self):
+        table_name = "testTable_foreachBatch_path"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
+        df = df.union(df)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    # write to delta table?
+
+    @staticmethod
+    def my_test_function_2():
+        return 2
+
+    def test_streaming_foreachBatch_fuction_calling(self):
+        def my_test_function_3():
+            return 3
+
+        table_name = "testTable_foreachBatch_function"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame(
+                [
+                    (my_test_function_1(),),
+                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                    (my_test_function_3(),),
+                ]
+            )
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.createDataFrame(
+            [
+                (my_test_function_1(),),
+                (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                (my_test_function_3(),),
+            ]
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_import(self):
+        import time  # not imported in foreachBatch_worker
+
+        table_name = "testTable_foreachBatch_import"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            time.sleep(1)
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load("python/test_support/sql/streaming")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
 
 class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
     pass
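
All four new tests share one skeleton: the foreachBatch function does its work only for batch 0, the query is drained with processAllAvailable(), and the resulting table is compared against a batch-computed expectation. A condensed sketch of that skeleton, assuming an active session named spark; the sink table name is illustrative:

    from pyspark.sql.dataframe import DataFrame

    def func(df: DataFrame, batch_id: int):
        if batch_id > 0:  # only process the first micro-batch
            return
        # df.sparkSession is the session attached to the batch DataFrame;
        # the spark_session test above exercises it explicitly.
        df.write.mode("append").saveAsTable("sink_table")

    q = spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
    q.processAllAvailable()  # block until all available input has been processed
    q.stop()
    assert spark.read.table("sink_table").count() > 0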
