@@ -130,7 +130,7 @@ def main(infile: IO, outfile: IO) -> None:

        if not isinstance(data_source, DataSource):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "a Python data source instance of type 'DataSource'",
                    "actual": f"'{type(data_source).__name__}'",
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
        schema = _parse_datatype_json_string(schema_json)
        if not isinstance(schema, StructType):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "an output schema of type 'StructType'",
                    "actual": f"'{type(schema).__name__}'",
46 changes: 44 additions & 2 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -19,7 +19,7 @@
import unittest
from typing import Callable, Union

-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, AnalysisException
from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
@@ -154,7 +154,8 @@ def test_data_source_read_output_named_row_with_wrong_schema(self):
            read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
        )
        with self.assertRaisesRegex(
-            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+            PythonException,
+            r"\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\] Return schema mismatch in the result",
        ):
            self.spark.read.format("test").load().show()

@@ -373,6 +374,47 @@ def test_case_insensitive_dict(self):
        self.assertEqual(d2["BaR"], 3)
        self.assertEqual(d2["baz"], 3)

+    def test_data_source_type_mismatch(self):
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "id int"
+
+            def reader(self, schema):
+                return TestReader()
+
+            def writer(self, schema, overwrite):
+                return TestWriter()
+
+        class TestReader:
+            def partitions(self):
+                return []
+
+            def read(self, partition):
+                yield (0,)
+
+        class TestWriter:
+            def write(self, iterator):
+                return WriterCommitMessage()
+
+        self.spark.dataSource.register(TestDataSource)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceReader",
+        ):
+            self.spark.read.format("test").load().show()
+
+        df = self.spark.range(10)
+        with self.assertRaisesRegex(
+            AnalysisException,
+            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceWriter",
+        ):
+            df.write.format("test").mode("append").saveAsTable("test_table")
+

class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
    ...
11 changes: 2 additions & 9 deletions python/pyspark/sql/worker/commit_data_source_write.py
@@ -60,14 +60,7 @@ def main(infile: IO, outfile: IO) -> None:

        # Receive the data source writer instance.
        writer = pickleSer._read_with_length(infile)
-        if not isinstance(writer, DataSourceWriter):
-            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
-                message_parameters={
-                    "expected": "an instance of DataSourceWriter",
-                    "actual": f"'{type(writer).__name__}'",
-                },
-            )
+        assert isinstance(writer, DataSourceWriter)

        # Receive the commit messages.
        num_messages = read_int(infile)
@@ -76,7 +69,7 @@ def main(infile: IO, outfile: IO) -> None:
            message = pickleSer._read_with_length(infile)
            if message is not None and not isinstance(message, WriterCommitMessage):
                raise PySparkAssertionError(
-                    error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
                    message_parameters={
                        "expected": "an instance of WriterCommitMessage",
                        "actual": f"'{type(message).__name__}'",
8 changes: 4 additions & 4 deletions python/pyspark/sql/worker/create_data_source.py
@@ -75,7 +75,7 @@ def main(infile: IO, outfile: IO) -> None:
        data_source_cls = read_command(pickleSer, infile)
        if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "a subclass of DataSource",
                    "actual": f"'{type(data_source_cls).__name__}'",
@@ -85,7 +85,7 @@ def main(infile: IO, outfile: IO) -> None:
        # Check the name method is a class method.
        if not inspect.ismethod(data_source_cls.name):
            raise PySparkTypeError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "'name()' method to be a classmethod",
                    "actual": f"'{type(data_source_cls.name).__name__}'",
@@ -98,7 +98,7 @@ def main(infile: IO, outfile: IO) -> None:
        # Check if the provider name matches the data source's name.
        if provider.lower() != data_source_cls.name().lower():
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": f"provider with name {data_source_cls.name()}",
                    "actual": f"'{provider}'",
@@ -111,7 +111,7 @@ def main(infile: IO, outfile: IO) -> None:
        user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
        if not isinstance(user_specified_schema, StructType):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "the user-defined schema to be a 'StructType'",
                    "actual": f"'{type(data_source_cls).__name__}'",
41 changes: 28 additions & 13 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -20,7 +20,7 @@
import functools
import pyarrow as pa
from itertools import islice
-from typing import IO, List, Iterator, Iterable, Tuple
+from typing import IO, List, Iterator, Iterable, Tuple, Union

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -32,7 +32,12 @@
)
from pyspark.sql import Row
from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
-from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    DataSourceStreamReader,
+    InputPartition,
+)
from pyspark.sql.datasource_internal import _streamReader
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.types import (
@@ -108,7 +113,7 @@ def batched(iterator: Iterator, n: int) -> Iterator:
                # Check if the names are the same as the schema.
                if set(result.__fields__) != col_name_set:
                    raise PySparkRuntimeError(
-                        error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                        error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
                        message_parameters={
                            "expected": str(column_names),
                            "actual": str(result.__fields__),
@@ -187,7 +192,7 @@ def main(infile: IO, outfile: IO) -> None:
        schema = _parse_datatype_json_string(schema_json)
        if not isinstance(schema, StructType):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "an output schema of type 'StructType'",
                    "actual": f"'{type(schema).__name__}'",
@@ -204,11 +209,21 @@ def main(infile: IO, outfile: IO) -> None:
        is_streaming = read_bool(infile)

        # Instantiate data source reader.
-        reader = (
-            _streamReader(data_source, schema)
-            if is_streaming
-            else data_source.reader(schema=schema)
-        )
+        if is_streaming:
+            reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(
+                data_source, schema
+            )
+        else:
+            reader = data_source.reader(schema=schema)
+            # Validate the reader.
+            if not isinstance(reader, DataSourceReader):
+                raise PySparkAssertionError(
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
+                    message_parameters={
+                        "expected": "an instance of DataSourceReader",
+                        "actual": f"'{type(reader).__name__}'",
+                    },
+                )

        # Create input converter.
        converter = ArrowTableToRowsConversion._create_converter(BinaryType())
@@ -241,7 +256,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
                    f"but found '{type(partition).__name__}'."
                )

-            output_iter = reader.read(partition)  # type: ignore[attr-defined]
+            output_iter = reader.read(partition)  # type: ignore[arg-type]

            # Validate the output iterator.
            if not isinstance(output_iter, Iterator):
@@ -264,7 +279,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
        if not is_streaming:
            # The partitioning of python batch source read is determined before query execution.
            try:
-                partitions = reader.partitions()  # type: ignore[attr-defined]
+                partitions = reader.partitions()  # type: ignore[call-arg]
                if not isinstance(partitions, list):
                    raise PySparkRuntimeError(
                        error_class="DATA_SOURCE_TYPE_MISMATCH",
@@ -283,9 +298,9 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
                        },
                    )
                if len(partitions) == 0:
-                    partitions = [None]
+                    partitions = [None]  # type: ignore[list-item]
            except NotImplementedError:
-                partitions = [None]
+                partitions = [None]  # type: ignore[list-item]

            # Return the serialized partition values.
            write_int(len(partitions), outfile)
6 changes: 3 additions & 3 deletions python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -70,7 +70,7 @@ def main(infile: IO, outfile: IO) -> None:

        if not isinstance(data_source, DataSource):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "a Python data source instance of type 'DataSource'",
                    "actual": f"'{type(data_source).__name__}'",
@@ -81,7 +81,7 @@ def main(infile: IO, outfile: IO) -> None:
        schema = _parse_datatype_json_string(schema_json)
        if not isinstance(schema, StructType):
            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                message_parameters={
                    "expected": "an output schema of type 'StructType'",
                    "actual": f"'{type(schema).__name__}'",
@@ -101,7 +101,7 @@ def main(infile: IO, outfile: IO) -> None:
            message = pickleSer._read_with_length(infile)
            if message is not None and not isinstance(message, WriterCommitMessage):
                raise PySparkAssertionError(
-                    error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
                    message_parameters={
                        "expected": "an instance of WriterCommitMessage",
                        "actual": f"'{type(message).__name__}'",
15 changes: 14 additions & 1 deletion python/pyspark/sql/worker/write_into_data_source.py
@@ -29,7 +29,12 @@
    SpecialLengths,
)
from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceWriter,
+    WriterCommitMessage,
+    CaseInsensitiveDict,
+)
from pyspark.sql.types import (
    _parse_datatype_json_string,
    StructType,
@@ -162,6 +167,14 @@ def main(infile: IO, outfile: IO) -> None:
        else:
            # Instantiate the data source writer.
            writer = data_source.writer(schema, overwrite)  # type: ignore[assignment]
+            if not isinstance(writer, DataSourceWriter):
+                raise PySparkAssertionError(
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
+                    message_parameters={
+                        "expected": "an instance of DataSourceWriter",
+                        "actual": f"'{type(writer).__name__}'",
+                    },
+                )

        # Create a function that can be used in mapInArrow.
        import pyarrow as pa
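For reference, the checks above expect reader() and writer() to return instances of DataSourceReader and DataSourceWriter from pyspark.sql.datasource. The following is a minimal sketch of a batch source that satisfies the added isinstance validations; the names (RangeSource, RangeReader, RangeWriter, the "my_range" format) are hypothetical and not part of this change, and the usage lines assume an active SparkSession as in the tests above.

from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceWriter,
    WriterCommitMessage,
)


class RangeSource(DataSource):
    @classmethod
    def name(cls):
        return "my_range"

    def schema(self):
        return "id int"

    def reader(self, schema):
        # Must return a DataSourceReader instance; otherwise the read worker
        # now raises DATA_SOURCE_TYPE_MISMATCH.
        return RangeReader()

    def writer(self, schema, overwrite):
        # Must return a DataSourceWriter instance; otherwise the write worker
        # now raises DATA_SOURCE_TYPE_MISMATCH.
        return RangeWriter()


class RangeReader(DataSourceReader):
    def partitions(self):
        return []

    def read(self, partition):
        for i in range(3):
            yield (i,)


class RangeWriter(DataSourceWriter):
    def write(self, iterator):
        for _ in iterator:
            pass
        return WriterCommitMessage()


# Hypothetical usage, assuming an active SparkSession named `spark`:
# spark.dataSource.register(RangeSource)
# spark.read.format("my_range").load().show()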