From ce5757c0514eff03fe3a3ff7114fdafd153093d9 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 18 Jun 2024 13:11:31 -0700 Subject: [PATCH 1/2] fix --- .../python_streaming_source_runner.py | 4 +- .../sql/tests/test_python_datasource.py | 46 ++++++++++++++++++- .../sql/worker/commit_data_source_write.py | 11 +---- .../pyspark/sql/worker/create_data_source.py | 8 ++-- .../sql/worker/plan_data_source_read.py | 24 ++++++---- .../worker/python_streaming_sink_runner.py | 6 +-- .../sql/worker/write_into_data_source.py | 15 +++++- 7 files changed, 85 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index 5292e2f92784..754ecff61b97 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -130,7 +130,7 @@ def main(infile: IO, outfile: IO) -> None: if not isinstance(data_source, DataSource): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "a Python data source instance of type 'DataSource'", "actual": f"'{type(data_source).__name__}'", @@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None: schema = _parse_datatype_json_string(schema_json) if not isinstance(schema, StructType): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "an output schema of type 'StructType'", "actual": f"'{type(schema).__name__}'", diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index d028a210b007..8431e9b3e35d 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -19,7 +19,7 @@ import unittest from typing import Callable, Union -from pyspark.errors import PythonException +from pyspark.errors import PythonException, AnalysisException from pyspark.sql.datasource import ( DataSource, DataSourceReader, @@ -154,7 +154,8 @@ def test_data_source_read_output_named_row_with_wrong_schema(self): read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)]) ) with self.assertRaisesRegex( - PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH" + PythonException, + r"\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\] Return schema mismatch in the result", ): self.spark.read.format("test").load().show() @@ -373,6 +374,47 @@ def test_case_insensitive_dict(self): self.assertEqual(d2["BaR"], 3) self.assertEqual(d2["baz"], 3) + def test_data_source_type_mismatch(self): + class TestDataSource(DataSource): + @classmethod + def name(cls): + return "test" + + def schema(self): + return "id int" + + def reader(self, schema): + return TestReader() + + def writer(self, schema, overwrite): + return TestWriter() + + class TestReader: + def partitions(self): + return [] + + def read(self, partition): + yield (0,) + + class TestWriter: + def write(self, iterator): + return WriterCommitMessage() + + self.spark.dataSource.register(TestDataSource) + + with self.assertRaisesRegex( + AnalysisException, + r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceReader", + ): + self.spark.read.format("test").load().show() + + df = self.spark.range(10) + with self.assertRaisesRegex( + AnalysisException, + r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceWriter", + ): + df.write.format("test").mode("append").saveAsTable("test_table") + class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): ... diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py index cf22c19ab3eb..9f3a176ed74d 100644 --- a/python/pyspark/sql/worker/commit_data_source_write.py +++ b/python/pyspark/sql/worker/commit_data_source_write.py @@ -60,14 +60,7 @@ def main(infile: IO, outfile: IO) -> None: # Receive the data source writer instance. writer = pickleSer._read_with_length(infile) - if not isinstance(writer, DataSourceWriter): - raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", - message_parameters={ - "expected": "an instance of DataSourceWriter", - "actual": f"'{type(writer).__name__}'", - }, - ) + assert isinstance(writer, DataSourceWriter) # Receive the commit messages. num_messages = read_int(infile) @@ -76,7 +69,7 @@ def main(infile: IO, outfile: IO) -> None: message = pickleSer._read_with_length(infile) if message is not None and not isinstance(message, WriterCommitMessage): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "an instance of WriterCommitMessage", "actual": f"'{type(message).__name__}'", diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index 33394cdff876..d6b59b04393d 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -75,7 +75,7 @@ def main(infile: IO, outfile: IO) -> None: data_source_cls = read_command(pickleSer, infile) if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "a subclass of DataSource", "actual": f"'{type(data_source_cls).__name__}'", @@ -85,7 +85,7 @@ def main(infile: IO, outfile: IO) -> None: # Check the name method is a class method. if not inspect.ismethod(data_source_cls.name): raise PySparkTypeError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "'name()' method to be a classmethod", "actual": f"'{type(data_source_cls.name).__name__}'", @@ -98,7 +98,7 @@ def main(infile: IO, outfile: IO) -> None: # Check if the provider name matches the data source's name. if provider.lower() != data_source_cls.name().lower(): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": f"provider with name {data_source_cls.name()}", "actual": f"'{provider}'", @@ -111,7 +111,7 @@ def main(infile: IO, outfile: IO) -> None: user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if not isinstance(user_specified_schema, StructType): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "the user-defined schema to be a 'StructType'", "actual": f"'{type(data_source_cls).__name__}'", diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index be7ebd20f180..4ba9f2cdc0d2 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -32,7 +32,7 @@ ) from pyspark.sql import Row from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion -from pyspark.sql.datasource import DataSource, InputPartition +from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition from pyspark.sql.datasource_internal import _streamReader from pyspark.sql.pandas.types import to_arrow_schema from pyspark.sql.types import ( @@ -108,7 +108,7 @@ def batched(iterator: Iterator, n: int) -> Iterator: # Check if the names are the same as the schema. if set(result.__fields__) != col_name_set: raise PySparkRuntimeError( - error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH", + error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH", message_parameters={ "expected": str(column_names), "actual": str(result.__fields__), @@ -187,7 +187,7 @@ def main(infile: IO, outfile: IO) -> None: schema = _parse_datatype_json_string(schema_json) if not isinstance(schema, StructType): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "an output schema of type 'StructType'", "actual": f"'{type(schema).__name__}'", @@ -204,11 +204,19 @@ def main(infile: IO, outfile: IO) -> None: is_streaming = read_bool(infile) # Instantiate data source reader. - reader = ( - _streamReader(data_source, schema) - if is_streaming - else data_source.reader(schema=schema) - ) + if is_streaming: + reader = _streamReader(data_source, schema) + else: + reader = data_source.reader(schema=schema) + # Validate the reader. + if not isinstance(reader, DataSourceReader): + raise PySparkAssertionError( + error_class="DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an instance of DataSourceReader", + "actual": f"'{type(reader).__name__}'", + }, + ) # Create input converter. converter = ArrowTableToRowsConversion._create_converter(BinaryType()) diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py index 98a7a22d0a6f..42aa7593f18d 100644 --- a/python/pyspark/sql/worker/python_streaming_sink_runner.py +++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py @@ -70,7 +70,7 @@ def main(infile: IO, outfile: IO) -> None: if not isinstance(data_source, DataSource): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "a Python data source instance of type 'DataSource'", "actual": f"'{type(data_source).__name__}'", @@ -81,7 +81,7 @@ def main(infile: IO, outfile: IO) -> None: schema = _parse_datatype_json_string(schema_json) if not isinstance(schema, StructType): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "an output schema of type 'StructType'", "actual": f"'{type(schema).__name__}'", @@ -101,7 +101,7 @@ def main(infile: IO, outfile: IO) -> None: message = pickleSer._read_with_length(infile) if message is not None and not isinstance(message, WriterCommitMessage): raise PySparkAssertionError( - error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + error_class="DATA_SOURCE_TYPE_MISMATCH", message_parameters={ "expected": "an instance of WriterCommitMessage", "actual": f"'{type(message).__name__}'", diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index 5714f35cbe71..212a2754ec9f 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -29,7 +29,12 @@ SpecialLengths, ) from pyspark.sql import Row -from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict +from pyspark.sql.datasource import ( + DataSource, + DataSourceWriter, + WriterCommitMessage, + CaseInsensitiveDict, +) from pyspark.sql.types import ( _parse_datatype_json_string, StructType, @@ -162,6 +167,14 @@ def main(infile: IO, outfile: IO) -> None: else: # Instantiate the data source writer. writer = data_source.writer(schema, overwrite) # type: ignore[assignment] + if not isinstance(writer, DataSourceWriter): + raise PySparkAssertionError( + error_class="DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an instance of DataSourceWriter", + "actual": f"'{type(writer).__name__}'", + }, + ) # Create a function that can be used in mapInArrow. import pyarrow as pa From 5ee2efca7ae58977b6d5f11016822488141b6833 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Thu, 20 Jun 2024 11:32:24 -0700 Subject: [PATCH 2/2] fix mypy --- .../sql/worker/plan_data_source_read.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 4ba9f2cdc0d2..51a90bba1454 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -20,7 +20,7 @@ import functools import pyarrow as pa from itertools import islice -from typing import IO, List, Iterator, Iterable, Tuple +from typing import IO, List, Iterator, Iterable, Tuple, Union from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkRuntimeError @@ -32,7 +32,12 @@ ) from pyspark.sql import Row from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion -from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition +from pyspark.sql.datasource import ( + DataSource, + DataSourceReader, + DataSourceStreamReader, + InputPartition, +) from pyspark.sql.datasource_internal import _streamReader from pyspark.sql.pandas.types import to_arrow_schema from pyspark.sql.types import ( @@ -205,7 +210,9 @@ def main(infile: IO, outfile: IO) -> None: # Instantiate data source reader. if is_streaming: - reader = _streamReader(data_source, schema) + reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader( + data_source, schema + ) else: reader = data_source.reader(schema=schema) # Validate the reader. @@ -249,7 +256,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec f"but found '{type(partition).__name__}'." ) - output_iter = reader.read(partition) # type: ignore[attr-defined] + output_iter = reader.read(partition) # type: ignore[arg-type] # Validate the output iterator. if not isinstance(output_iter, Iterator): @@ -272,7 +279,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec if not is_streaming: # The partitioning of python batch source read is determined before query execution. try: - partitions = reader.partitions() # type: ignore[attr-defined] + partitions = reader.partitions() # type: ignore[call-arg] if not isinstance(partitions, list): raise PySparkRuntimeError( error_class="DATA_SOURCE_TYPE_MISMATCH", @@ -291,9 +298,9 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec }, ) if len(partitions) == 0: - partitions = [None] + partitions = [None] # type: ignore[list-item] except NotImplementedError: - partitions = [None] + partitions = [None] # type: ignore[list-item] # Return the serialized partition values. write_int(len(partitions), outfile)