From 94f6248af52d8ded3a7a3510038b970fa7b1b53c Mon Sep 17 00:00:00 2001
From: Chaoqin Li
Date: Thu, 11 Apr 2024 10:10:41 -0700
Subject: [PATCH 1/2] [SPARK-47777][PYTHON][SS][TESTS] Add spark connect test
 for python streaming data source

### What changes were proposed in this pull request?
Make the Python streaming data source PySpark tests also run on Spark Connect. Refactor the test, because foreachBatch runs on the Spark Connect server and therefore does not update local variables captured from the client.

### Why are the changes needed?
To test the Python streaming data source on Spark Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Test change.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #45950 from chaoqin-li1123/source_spark_connect.

Authored-by: Chaoqin Li
Signed-off-by: Dongjoon Hyun
---
 dev/sparktestsupport/modules.py               |  1 +
 ...test_parity_python_streaming_datasource.py | 39 +++++++++++++++++++
 .../tests/test_python_streaming_datasource.py |  6 +--
 3 files changed, 41 insertions(+), 5 deletions(-)
 create mode 100644 python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 701203414702..5e169eb119b4 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1047,6 +1047,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
+        "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
diff --git a/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
new file mode 100644
index 000000000000..65bb4c021f4d
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.tests.test_python_streaming_datasource import (
+    BasePythonStreamingDataSourceTestsMixin,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PythonStreamingDataSourceParityTests(
+    BasePythonStreamingDataSourceTestsMixin, ReusedConnectTestCase
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_python_streaming_datasource import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index 53fb5ca9381f..f7247599be83 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -140,15 +140,11 @@ def test_stream_reader(self):
         self.spark.dataSource.register(self._get_test_data_source())
         df = self.spark.readStream.format("TestDataSource").load()
 
-        current_batch_id = -1
-
         def check_batch(df, batch_id):
-            nonlocal current_batch_id
-            current_batch_id = batch_id
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while current_batch_id < 10:
+        while len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()

From 720a3a29ff02fe96403a044b3691add7bab787fa Mon Sep 17 00:00:00 2001
From: Chaoqin Li
Date: Mon, 10 Jun 2024 12:07:54 -0700
Subject: [PATCH 2/2] fix

---
 python/pyspark/sql/streaming/python_streaming_source_runner.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index bd779df4837b..5292e2f92784 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -21,7 +21,6 @@
 from typing import IO, Iterator, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.java_gateway import local_connect_and_auth
 from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError
 from pyspark.serializers import (
     read_int,
@@ -37,7 +36,7 @@
     StructType,
 )
 from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
-from pyspark.util import handle_worker_exception
+from pyspark.util import handle_worker_exception, local_connect_and_auth
 from pyspark.worker_util import (
     check_python_version,
     read_command,
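
A note on the refactor in PATCH 1/2: under Spark Connect, the function passed to `foreachBatch` is serialized and executed on the server, so a `nonlocal` counter mutated inside the callback (the old test's approach) never changes in the client process. Polling `StreamingQuery.recentProgress` on the client works in both execution modes. Below is a minimal, self-contained sketch of that pattern; it uses the built-in `rate` source rather than the suite's custom `TestDataSource` so it runs standalone, and is an illustration, not the actual test code.

```python
import time

from pyspark.sql import SparkSession

# Works with a classic local session; for Spark Connect, build the session
# with SparkSession.builder.remote("sc://...").getOrCreate() instead.
spark = SparkSession.builder.getOrCreate()

# Built-in "rate" source for illustration; the real test registers a custom
# Python streaming data source via spark.dataSource.register(...).
df = spark.readStream.format("rate").load()


def check_batch(batch_df, batch_id):
    # Under Spark Connect this function runs on the server: mutating a
    # client-side variable here (e.g. via `nonlocal`) would never be seen
    # by the polling loop below. Keep per-batch assertions here instead.
    assert batch_df.count() >= 0


q = df.writeStream.foreachBatch(check_batch).start()

# recentProgress is populated on the client as micro-batches complete,
# so its length is a reliable "number of batches so far" signal.
while len(q.recentProgress) < 10:
    time.sleep(0.2)

q.stop()
q.awaitTermination()
```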