diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b97ec34b5382..8c17af559c25 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1048,6 +1048,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
+        "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index bd779df4837b..5292e2f92784 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -21,7 +21,6 @@
 from typing import IO, Iterator, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.java_gateway import local_connect_and_auth
 from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError
 from pyspark.serializers import (
     read_int,
@@ -37,7 +36,7 @@
     StructType,
 )
 from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
-from pyspark.util import handle_worker_exception
+from pyspark.util import handle_worker_exception, local_connect_and_auth
 from pyspark.worker_util import (
     check_python_version,
     read_command,
diff --git a/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
new file mode 100644
index 000000000000..65bb4c021f4d
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.tests.test_python_streaming_datasource import (
+    BasePythonStreamingDataSourceTestsMixin,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PythonStreamingDataSourceParityTests(
+    BasePythonStreamingDataSourceTestsMixin, ReusedConnectTestCase
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_python_streaming_datasource import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index e5622e28f15b..183b0ad80d9d 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -142,15 +142,11 @@ def test_stream_reader(self):
         self.spark.dataSource.register(self._get_test_data_source())
         df = self.spark.readStream.format("TestDataSource").load()
 
-        current_batch_id = -1
-
         def check_batch(df, batch_id):
-            nonlocal current_batch_id
-            current_batch_id = batch_id
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while current_batch_id < 10:
+        while len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
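
A note on the last hunk: under Spark Connect, the function passed to
foreachBatch is serialized and executed on the server, so mutating a
client-side closure variable (the old nonlocal current_batch_id) is never
visible to the polling loop in the client process. Polling
StreamingQuery.recentProgress, which is reported back to the client in both
classic and Connect sessions, works for either mode. Below is a minimal,
self-contained sketch of that pattern; it uses the built-in "rate" source
as a stand-in for the test's registered TestDataSource.

    import time

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    def check_batch(batch_df, batch_id):
        # Under Spark Connect this function runs server-side; client-side
        # state mutated here (e.g. via nonlocal) would never be observed
        # by the polling loop below.
        batch_df.collect()

    q = (
        spark.readStream.format("rate")  # built-in streaming source
        .load()
        .writeStream.foreachBatch(check_batch)
        .start()
    )

    # recentProgress is populated on the query handle in both classic and
    # Connect modes, making it a mode-agnostic way to wait until a few
    # microbatches have completed.
    while len(q.recentProgress) < 3:
        time.sleep(0.2)
    q.stop()
    q.awaitTermination()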