@@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging {

   val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
-    // TODO(SPARK-44460): Support Auth credentials
     // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
     // This is because MicroBatch execution clones the session during start.
     // The session attached to the foreachBatch dataframe is different from the one the one
@@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     func = worker.read_command(pickle_ser, infile)
@@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     listener = worker.read_command(pickle_ser, infile)
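For context, both connect streaming workers attach to the existing server-side session instead of creating a fresh one: they build a Spark Connect client against the local server and then point it at the session that owns the query. A minimal sketch of that pattern, assuming a running Spark Connect server; the connect_url and session_id values are hypothetical placeholders (the real workers receive them over the worker socket):

from pyspark.sql import SparkSession

# Hypothetical placeholders; the real workers read these from the worker socket.
connect_url = "sc://localhost:15002"
session_id = "00000000-0000-0000-0000-000000000000"

# Build a Spark Connect client session against the server...
spark = SparkSession.builder.remote(connect_url).getOrCreate()

# ...then reuse the existing server-side session by overriding the client's
# session id (a private attribute, exactly as in the worker code above).
spark._client._session_id = session_id  # type: ignore[attr-defined]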
python/pyspark/sql/tests/connect/streaming/test_parity_listener.py (57 changes: 34 additions & 23 deletions)
@@ -18,39 +18,31 @@
 import unittest
 import time
 
+import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-def get_start_event_schema():
-    return StructType(
-        [
-            StructField("id", StringType(), True),
-            StructField("runId", StringType(), True),
-            StructField("name", StringType(), True),
-            StructField("timestamp", StringType(), True),
-        ]
-    )
-
-
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        df = self.spark.createDataFrame(
-            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
-            schema=get_start_event_schema(),
-        )
-        df.write.saveAsTable("listener_start_events")
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events")
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ def test_listener_events(self):
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
 
             self.assertTrue(q.isActive)
             time.sleep(10)
+            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
             q.stop()
             self.assertFalse(q.isActive)
 
-            start_event = QueryStartedEvent.fromJson(
-                self.spark.read.table("listener_start_events").collect()[0].asDict()
-            )
+            start_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_start_events").collect()[0][0]
+            )
+
+            progress_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_progress_events").collect()[0][0]
+            )
+
+            terminated_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_terminated_events").collect()[0][0]
+            )
 
             self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
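The rewritten listener test no longer flattens events into a fixed schema; each handler cloudpickles the whole event into a single binary column, and the assertions unpickle it after reading the table back. A minimal sketch of that round trip, assuming only that the vendored pyspark.cloudpickle module behaves like standard cloudpickle and that the event object is picklable (the dict below is a stand-in, not a real listener event):

import pyspark.cloudpickle

# Stand-in for a StreamingQueryListener event; any picklable object works.
event = {"id": "a-query-id", "batchId": 3}

# Listener side: serialize the event to bytes that fit in one table column.
payload = pyspark.cloudpickle.dumps(event)

# Test side: read the row back and restore the original object.
restored = pyspark.cloudpickle.loads(payload)
assert restored == event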
python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py (111 changes: 110 additions & 1 deletion)
@@ -16,10 +16,14 @@
 #
 
 import time
 
+from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
+def my_test_function_1():
+    return 1
+
+
 class StreamingTestsForeachBatchMixin:
     def test_streaming_foreachBatch(self):
         q = None
@@ -88,6 +92,111 @@ def func(batch_df, _):
         q.stop()
         self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
+    def test_streaming_foreachBatch_spark_session(self):
+        table_name = "testTable_foreachBatch"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame([("structured",), ("streaming",)])
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = (
+            self.spark.read.format("text")
+            .load(path="python/test_support/sql/streaming/")
+            .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_path_access(self):
+        table_name = "testTable_foreachBatch_path"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
+        df = df.union(df)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    # write to delta table?
+
+    @staticmethod
+    def my_test_function_2():
+        return 2
+
+    def test_streaming_foreachBatch_function_calling(self):
+        def my_test_function_3():
+            return 3
+
+        table_name = "testTable_foreachBatch_function"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame(
+                [
+                    (my_test_function_1(),),
+                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                    (my_test_function_3(),),
+                ]
+            )
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.createDataFrame(
+            [
+                (my_test_function_1(),),
+                (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                (my_test_function_3(),),
+            ]
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_import(self):
+        import time  # not imported in foreachBatch_worker
+
+        table_name = "testTable_foreachBatch_import"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            time.sleep(1)
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load("python/test_support/sql/streaming")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
 
 class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
     pass
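For reference, a minimal sketch of the foreachBatch pattern these new tests exercise: the batch function receives each micro-batch as a plain DataFrame plus a batch id, and uses the session attached to that DataFrame for any extra reads or writes. The source, sink table name, and column handling below are illustrative, not taken from the test suite:

from pyspark.sql import DataFrame, SparkSession


def write_batch(df: DataFrame, batch_id: int) -> None:
    # Use the session attached to the micro-batch DataFrame rather than the
    # driver-side session object, since foreachBatch may run against a clone
    # of the session that started the query.
    spark = df.sparkSession
    marker = spark.createDataFrame([(f"batch-{batch_id}",)], ["value"])
    batch = df.selectExpr("CAST(value AS STRING) AS value")
    batch.union(marker).write.mode("append").saveAsTable("example_sink_table")


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    query = (
        spark.readStream.format("rate")
        .load()
        .select("value")
        .writeStream.foreachBatch(write_batch)
        .start()
    )
    query.processAllAvailable()
    query.stop()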