diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py index 5d2c1bbbf62c..ef286115a303 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py @@ -97,46 +97,48 @@ def func(batch_df, _): def test_streaming_foreach_batch_spark_session(self): table_name = "testTable_foreach_batch" + with self.table(table_name): - def func(df: DataFrame, batch_id: int): - if batch_id > 0: # only process once - return - spark = df.sparkSession - df1 = spark.createDataFrame([("structured",), ("streaming",)]) - df1.union(df).write.mode("append").saveAsTable(table_name) + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return + spark = df.sparkSession + df1 = spark.createDataFrame([("structured",), ("streaming",)]) + df1.union(df).write.mode("append").saveAsTable(table_name) - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(func).start() - q.processAllAvailable() - q.stop() + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() - actual = self.spark.read.table(table_name) - df = ( - self.spark.read.format("text") - .load(path="python/test_support/sql/streaming/") - .union(self.spark.createDataFrame([("structured",), ("streaming",)])) - ) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + actual = self.spark.read.table(table_name) + df = ( + self.spark.read.format("text") + .load(path="python/test_support/sql/streaming/") + .union(self.spark.createDataFrame([("structured",), ("streaming",)])) + ) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) def test_streaming_foreach_batch_path_access(self): table_name = "testTable_foreach_batch_path" + with self.table(table_name): - def func(df: DataFrame, batch_id: int): - if batch_id > 0: # only process once - return - spark = df.sparkSession - df1 = spark.read.format("text").load("python/test_support/sql/streaming") - df1.union(df).write.mode("append").saveAsTable(table_name) + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return + spark = df.sparkSession + df1 = spark.read.format("text").load("python/test_support/sql/streaming") + df1.union(df).write.mode("append").saveAsTable(table_name) - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(func).start() - q.processAllAvailable() - q.stop() + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() - actual = self.spark.read.table(table_name) - df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/") - df = df.union(df) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + actual = self.spark.read.table(table_name) + df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/") + df = df.union(df) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) @staticmethod def my_test_function_2(): @@ -147,56 +149,58 @@ def my_test_function_3(): return 3 table_name = "testTable_foreach_batch_function" + with self.table(table_name): + + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return + spark = df.sparkSession + df1 = spark.createDataFrame( + [ + (my_test_function_1(),), + (StreamingTestsForeachBatchMixin.my_test_function_2(),), + (my_test_function_3(),), + ] + ) + df1.write.mode("append").saveAsTable(table_name) + + df = self.spark.readStream.format("rate").load() + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() - def func(df: DataFrame, batch_id: int): - if batch_id > 0: # only process once - return - spark = df.sparkSession - df1 = spark.createDataFrame( + actual = self.spark.read.table(table_name) + df = self.spark.createDataFrame( [ (my_test_function_1(),), (StreamingTestsForeachBatchMixin.my_test_function_2(),), (my_test_function_3(),), ] ) - df1.write.mode("append").saveAsTable(table_name) - - df = self.spark.readStream.format("rate").load() - q = df.writeStream.foreachBatch(func).start() - q.processAllAvailable() - q.stop() - - actual = self.spark.read.table(table_name) - df = self.spark.createDataFrame( - [ - (my_test_function_1(),), - (StreamingTestsForeachBatchMixin.my_test_function_2(),), - (my_test_function_3(),), - ] - ) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) def test_streaming_foreach_batch_import(self): import time # not imported in foreach_batch_worker table_name = "testTable_foreach_batch_import" + with self.table(table_name): + + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return + time.sleep(1) + spark = df.sparkSession + df1 = spark.read.format("text").load("python/test_support/sql/streaming") + df1.write.mode("append").saveAsTable(table_name) + + df = self.spark.readStream.format("rate").load() + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() - def func(df: DataFrame, batch_id: int): - if batch_id > 0: # only process once - return - time.sleep(1) - spark = df.sparkSession - df1 = spark.read.format("text").load("python/test_support/sql/streaming") - df1.write.mode("append").saveAsTable(table_name) - - df = self.spark.readStream.format("rate").load() - q = df.writeStream.foreachBatch(func).start() - q.processAllAvailable() - q.stop() - - actual = self.spark.read.table(table_name) - df = self.spark.read.format("text").load("python/test_support/sql/streaming") - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + actual = self.spark.read.table(table_name) + df = self.spark.read.format("text").load("python/test_support/sql/streaming") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):