From de18e89bfcb4596f48a95933b78e95e6e674c133 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 00:02:10 +0100 Subject: [PATCH 1/8] Initial commit --- python/pyspark/sql/__init__.py | 4 +- .../sql/tests/connect/test_parity_types.py | 4 + python/pyspark/sql/tests/test_types.py | 249 ++++++++++++++++++ python/pyspark/sql/types.py | 190 +++++++++++++ .../spark/sql/catalyst/util/STUtils.java | 4 + .../sql/execution/python/EvaluatePython.scala | 27 +- 6 files changed, 474 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index a0a6e8ef70c8d..eeeeddd00e3af 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -39,7 +39,7 @@ - :class:`pyspark.sql.Window` For working with window functions. """ -from pyspark.sql.types import Row, VariantVal +from pyspark.sql.types import Geography, Geometry, Row, VariantVal from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column @@ -69,6 +69,8 @@ "DataFrameNaFunctions", "DataFrameStatFunctions", "VariantVal", + "Geography", + "Geometry", "Window", "WindowSpec", "DataFrameReader", diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index 6d06611def6af..a39e92233bc0e 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -34,6 +34,10 @@ def test_apply_schema_to_dict_and_rows(self): def test_apply_schema_to_row(self): super().test_apply_schema_to_row() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") + def test_geospatial_create_dataframe_rdd(self): + super().test_geospatial_create_dataframe_rdd() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_create_dataframe_schema_mismatch(self): super().test_create_dataframe_schema_mismatch() diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 6979095acca88..d19d604be97e8 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -29,6 +29,8 @@ from pyspark.sql import functions as F from pyspark.errors import ( AnalysisException, + IllegalArgumentException, + SparkRuntimeException, ParseException, PySparkTypeError, PySparkValueError, @@ -51,6 +53,8 @@ MapType, StringType, CharType, + Geography, + Geometry, VarcharType, StructType, StructField, @@ -1365,6 +1369,7 @@ def test_parse_datatype_json_string(self): NullType(), GeographyType(4326), GeographyType("ANY"), + GeometryType(0), GeometryType(4326), GeometryType("ANY"), VariantType(), @@ -2447,6 +2452,250 @@ def test_variant_type(self): with self.assertRaises(PySparkValueError, msg="Rows cannot be of type VariantVal"): self.spark.createDataFrame([VariantVal.parseJson("2")], "v variant") + def test_geospatial_encoding(self): + df = self.spark.createDataFrame( + [(bytes.fromhex("0101000000000000000000F03F0000000000000040"), 4326,)], + ["wkb", "srid"], + ) + row = df.select( + F.st_geomfromwkb(df.wkb).alias("geom"), + F.st_geogfromwkb(df.wkb).alias("geog"), + ).collect()[0] + + self.assertEqual(type(row["geom"]), Geometry) + self.assertEqual(type(row["geog"]), Geography) + self.assertEqual(row["geom"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")) + self.assertEqual(row["geom"].getSrid(), 0) + self.assertEqual(row["geog"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")) + self.assertEqual(row["geog"].getSrid(), 4326) + + def test_geospatial_create_dataframe_rdd(self): + schema = StructType([ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geog", GeographyType(4326), True) + ]) + geospatial_data = [ + (1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), + (2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + ] + rdd_data = self.sc.parallelize(geospatial_data) + df = self.spark.createDataFrame(rdd_data, schema) + rows = df.select( + F.st_asbinary(df.geom).alias("geom_wkb"), + F.st_srid(df.geom).alias("geom_srid"), + F.st_asbinary(df.geom4326).alias("geom4326_wkb"), + F.st_srid(df.geom4326).alias("geom4326_srid"), + F.st_asbinary(df.geog).alias("geog_wkb"), + F.st_srid(df.geog).alias("geog_srid") + ).collect() + + point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440") + self.assertEqual(rows[0]["geom_wkb"], point0_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb) + self.assertEqual(rows[0]["geog_wkb"], point0_wkb) + self.assertEqual(rows[1]["geom_wkb"], point1_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb) + self.assertEqual(rows[1]["geog_wkb"], point1_wkb) + self.assertEqual(rows[0]["geom_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geog_srid"], 4326) + self.assertEqual(rows[1]["geom_srid"], 0) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geog_srid"], 4326) + schema_df = self.spark.createDataFrame(rdd_data).select( + F.col("_1").alias("id"), + F.col("_2").alias("geom"), + F.col("_3").alias("geom4326"), + F.col("_4").alias("geog") + ) + self.assertEqual(df.collect(), schema_df.collect()) + + + def test_geospatial_create_dataframe(self): + # Positive Test: Creating DataFrame from a list of tuples with explicit schema + geospatial_data = [ + (1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), + (2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + ] + named_geospatial_data = [ + {"id": 1, "geom": Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + "geom4326": Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + "geog": Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)}, + {"id": 2, "geom": Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + "geom4326": Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + "geog": Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)} + ] + GeospatialRow = Row("id", "geom", "geom4326", "geog") + spark_row_data = [ + GeospatialRow(1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), + GeospatialRow(2, Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + ] + schema = StructType([ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geog", GeographyType(4326), True) + ]) + # Negative Test: Schema mismatch + mismatched_schema = StructType([ + StructField("id", IntegerType(), True), # Should be GeometryType + StructField("geom", GeometryType(4326), True), # Should be GeometryType + StructField("geom4326", GeometryType(4326), True), # Should be GeometryType + StructField("geog", GeographyType(4326), True) # Should be GeographyType + ]) + + # Explicitly validate single test case + # rest will be compared with this one. + df = self.spark.createDataFrame(geospatial_data, schema) + rows = df.select( + F.st_asbinary(df.geom).alias("geom_wkb"), + F.st_srid(df.geom).alias("geom_srid"), + F.st_asbinary(df.geom4326).alias("geom4326_wkb"), + F.st_srid(df.geom4326).alias("geom4326_srid"), + F.st_asbinary(df.geog).alias("geog_wkb"), + F.st_srid(df.geog).alias("geog_srid") + ).collect() + + point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440") + self.assertEqual(rows[0]["geom_wkb"], point0_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb) + self.assertEqual(rows[0]["geog_wkb"], point0_wkb) + self.assertEqual(rows[1]["geom_wkb"], point1_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb) + self.assertEqual(rows[1]["geog_wkb"], point1_wkb) + self.assertEqual(rows[0]["geom_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geog_srid"], 4326) + self.assertEqual(rows[1]["geom_srid"], 0) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geog_srid"], 4326) + + # Just the data set without parameters. + self.assertEqual(self.spark.createDataFrame(named_geospatial_data).select("id", "geom", "geom4326", "geog").collect(), df.collect()) + self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(), df.collect()) + self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(), df.collect()) + + # Define DataFrame creation methods + datasets = [ + named_geospatial_data, + geospatial_data, + spark_row_data + ] + schemas = [ + schema, + "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog GEOGRAPHY(4326)", + ["id", "geom", "geom4326", "geog"], + ] + + for dataset_to_check, schema_to_check in zip(datasets, schemas): + df_to_check = self.spark.createDataFrame(dataset_to_check, schema_to_check).select("id", "geom", "geom4326", "geog") + self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame creation with schema") + + # Negative Test: Schema mismatch + for dataset_to_check in datasets: + with self.assertRaises(SparkRuntimeException) as pe: + self.spark.createDataFrame(dataset_to_check, mismatched_schema).collect() + + self.check_error( + exception=pe.exception, + errorClass="GEO_ENCODER_SRID_MISMATCH_ERROR", + messageParameters={"type": "GEOMETRY", "typeSrid": "4326", "valueSrid": "0"}, + ) + + def test_geospatial_schema_inferrence(self): + # Mixed data with different SRIDs + wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + geometry_dataset = [ + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)), + (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)), + ] + # Create DataFrame with mixed data types + df = self.spark.createDataFrame(geometry_dataset, ["geom0", "geom4326", "geom_any"]) + expected_schema = StructType([ + StructField("geom0", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geom_any", GeometryType("ANY"), True) + ]) + self.assertEqual(df.schema, expected_schema) + + rows = df.select( + F.st_asbinary("geom0").alias("geom0_wkb"), + F.st_srid("geom0").alias("geom0_srid"), + F.st_asbinary("geom4326").alias("geom4326_wkb"), + F.st_srid("geom4326").alias("geom4326_srid"), + F.st_asbinary("geom_any").alias("geom_any_wkb"), + F.st_srid("geom_any").alias("geom_any_srid"), + ).collect() + + point_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") + self.assertEqual(rows[0]["geom0_wkb"], point_wkb) + self.assertEqual(rows[1]["geom0_wkb"], point_wkb) + self.assertEqual(rows[0]["geom4326_wkb"], point_wkb) + self.assertEqual(rows[1]["geom4326_wkb"], point_wkb) + self.assertEqual(rows[0]["geom_any_wkb"], point_wkb) + self.assertEqual(rows[1]["geom_any_wkb"], point_wkb) + self.assertEqual(rows[0]["geom0_srid"], 0) + self.assertEqual(rows[1]["geom0_srid"], 0) + self.assertEqual(rows[0]["geom4326_srid"], 4326) + self.assertEqual(rows[1]["geom4326_srid"], 4326) + self.assertEqual(rows[0]["geom_any_srid"], 4326) + self.assertEqual(rows[1]["geom_any_srid"], 0) + + def test_geospatial_mixed_check_srid_validity(self): + geometry_mixed_invalid_data = [ + (1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0)), + (2, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 1)) + ] + + with self.assertRaises(IllegalArgumentException) as pe: + self.spark.createDataFrame(geometry_mixed_invalid_data).collect() + self.check_error( + exception=pe.exception, + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={"srid": "1"}, + ) + + with self.assertRaises(IllegalArgumentException) as pe: + self.spark.createDataFrame(geometry_mixed_invalid_data, f"id INT, geom GEOMETRY(ANY)").collect() + self.check_error( + exception=pe.exception, + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={"srid": "1"}, + ) + + def test_geospatial_result_encoding(self): + point_bytes = bytes.fromhex("010100000000000000000031400000000000001c40") + df = self.spark.sql(""" + SELECT ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 0) AS geom, + ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 4326) AS geom4326, + ST_GeogFromWKB(X'010100000000000000000031400000000000001c40') AS geog""") + GeospatialRow = Row("geom", "geom4326", "geog") + self.assertEqual(df.collect(), [GeospatialRow(Geometry.fromWKB(point_bytes, 0), + Geometry.fromWKB(point_bytes, 4326), + Geography.fromWKB(point_bytes, 4326))]) + def test_to_ddl(self): schema = StructType().add("a", NullType()).add("b", BooleanType()).add("c", BinaryType()) self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY") diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 440100dba9312..52b9b7109d005 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -90,6 +90,8 @@ "TimestampNTZType", "DecimalType", "DoubleType", + "Geography", + "Geometry", "FloatType", "ByteType", "IntegerType", @@ -616,6 +618,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]: # The JSON representation always uses the CRS and algorithm value. return f"geography({self._crs}, {self._alg})" + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["Geography"]: + if obj is None or not all(key in obj for key in ["srid", "bytes"]): + return None + return Geography(obj["bytes"], obj["srid"]) + + def toInternal(self, geography: Any) -> Any: + if geography is None: + return None + assert isinstance(geography, Geography) + return {"srid": geography.srid, "wkb": geography.wkb} + class GeometryType(SpatialType): """ @@ -2039,6 +2055,148 @@ def parseJson(cls, json_str: str) -> "VariantVal": return VariantVal(value, metadata) +class Geography: + """ + A class to represent a Geography value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geography. + + srid : integer + The integer value representing SRID of Geography. + + Methods + ------- + getBytes() + Returns the WKB of Geography. + + getSrid() + Returns the SRID of Geography. + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.geogfromwkb(df.geogwkb).alias("geog")).head().geom # doctest: +SKIP + >>> g.getBytes() # doctest: +SKIP + b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' + >>> g.getSrid() # doctest: +SKIP + 4326 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geography(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geometry. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geometry. + """ + return self.wkb + + def __eq__(self, other): + if not isinstance(other, Geography): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geography": + """ + Construct Python Geography object from WKB. + :return: Python representation of the Geography type value. + """ + return Geography(wkb, srid) + + +class Geometry: + """ + A class to represent a Geometry value in Python. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + wkb : bytes + The bytes representing the WKB of Geometry. + + srid : integer + The integer value representing SRID of Geometry. + + Methods + ------- + getBytes() + Returns the WKB of Geometry. + + getSrid() + Returns the SRID of Geometry. + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(bytes.fromhex("010100000000000000000031400000000000001C40"),)], ["wkb"],) # noqa + >>> g = df.select(sf.geomfromwkb(df.geomwkb).alias("geom")).head().geom # doctest: +SKIP + >>> g.getBytes() # doctest: +SKIP + b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x001@\x00\x00\x00\x00\x00\x00\x1c@' + >>> g.getSrid() # doctest: +SKIP + 0 + """ + + def __init__(self, wkb: bytes, srid: int): + self.wkb = wkb + self.srid = srid + + def __str__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def __repr__(self) -> str: + return "Geometry(%r, %d)" % (self.wkb, self.srid) + + def getSrid(self) -> int: + """ + Returns the SRID of Geometry. + """ + return self.srid + + def getBytes(self) -> bytes: + """ + Returns the WKB of Geometry. + """ + return self.wkb + + def __eq__(self, other): + if not isinstance(other, Geometry): + # Don't attempt to compare against unrelated types. + return NotImplemented + + return self.wkb == other.wkb and self.srid == other.srid + + @classmethod + def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry": + """ + Construct Python Geometry object from WKB. + :return: Python representation of the Geometry type value. + """ + return Geometry(wkb, srid) + + _atomic_types: List[Type[DataType]] = [ StringType, CharType, @@ -2349,6 +2507,8 @@ def _assert_valid_collation_provider(provider: str) -> None: # Mapping Python types to Spark SQL DataType _type_mappings = { type(None): NullType, + Geometry: GeometryType, + Geography: GeographyType, bool: BooleanType, int: LongType, float: DoubleType, @@ -2480,6 +2640,12 @@ def _infer_type( return obj.__UDT__ dataType = _type_mappings.get(type(obj)) + if dataType is GeographyType: + assert isinstance(obj, Geography) + return GeographyType(obj.getSrid()) + if dataType is GeometryType: + assert isinstance(obj, Geometry) + return GeometryType(obj.getSrid()) if dataType is DecimalType: # the precision and scale of `obj` may be different from row to row. return DecimalType(38, 18) @@ -2747,6 +2913,10 @@ def new_name(n: str) -> str: return a elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType): return b + elif isinstance(a, GeometryType) and isinstance(b, GeometryType) and a.srid != b.srid: + return GeometryType("ANY") + elif isinstance(a, GeographyType) and isinstance(b, GeographyType) and a.srid != b.srid: + return GeographyType("ANY") elif isinstance(a, AtomicType) and isinstance(b, StringType): return b elif isinstance(a, StringType) and isinstance(b, AtomicType): @@ -2900,6 +3070,8 @@ def convert_struct(obj: Any) -> Optional[Tuple]: ArrayType: (list, tuple, array), MapType: (dict,), StructType: (tuple, list, dict), + GeometryType: (Geometry,), + GeographyType: (Geography,), VariantType: ( bool, int, @@ -3251,6 +3423,24 @@ def verify_variant(obj: Any) -> None: verify_value = verify_variant + elif isinstance(dataType, GeometryType): + + def verify_geometry(obj: Any) -> None: + assert_acceptable_types(obj) + verify_acceptable_types(obj) + assert isinstance(obj, Geometry) + + verify_value = verify_geometry + + elif isinstance(dataType, GeographyType): + + def verify_geography(obj: Any) -> None: + assert_acceptable_types(obj) + verify_acceptable_types(obj) + assert isinstance(obj, Geography) + + verify_value = verify_geography + else: def verify_default(obj: Any) -> None: diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java index 9edeee26eb98a..0f77bcbbaa47d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java @@ -101,6 +101,10 @@ public static GeometryVal stGeomFromWKB(byte[] wkb) { return toPhysVal(Geometry.fromWkb(wkb)); } + public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) { + return toPhysVal(Geometry.fromWkb(wkb, srid)); + } + // ST_Srid public static int stSrid(GeographyVal geog) { return fromPhysVal(geog).srid(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 212cc5db124ce..33622ca7349a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{UTF8String, VariantVal} +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} object EvaluatePython { @@ -43,7 +43,7 @@ object EvaluatePython { def needConversionInPython(dt: DataType): Boolean = dt match { case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType - | _: TimeType => true + | _: TimeType | _: GeometryType | _: GeographyType => true case _: StructType => true case _: UserDefinedType[_] => true case ArrayType(elementType, _) => needConversionInPython(elementType) @@ -92,6 +92,10 @@ object EvaluatePython { case (s: UTF8String, _: StringType) => s.toString + case (g: GeometryVal, gt: GeometryType) => STUtils.deserializeGeom(g, gt) + + case (g: GeographyVal, gt: GeographyType) => STUtils.deserializeGeog(g, gt) + case (bytes: Array[Byte], BinaryType) => if (binaryAsBytes) { new BytesWrapper(bytes) @@ -228,6 +232,23 @@ object EvaluatePython { ) } + case g: GeographyType => (obj: Any) => nullSafeConvert(obj) { + case s: java.util.HashMap[_, _] => + val geographySrid = s.get("srid").asInstanceOf[Int] + g.assertSridAllowedForType(geographySrid) + STUtils.stGeogFromWKB( + s.get("wkb").asInstanceOf[Array[Byte]]) + } + + case g: GeometryType => (obj: Any) => nullSafeConvert(obj) { + case s: java.util.HashMap[_, _] => + val geometrySrid = s.get("srid").asInstanceOf[Int] + g.assertSridAllowedForType(geometrySrid) + STUtils.stGeomFromWKB( + s.get("wkb").asInstanceOf[Array[Byte]], + geometrySrid) + } + case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty) } From 1b8ab1631ddde6ffeebfbf6cc5b9e54ef6942230 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 07:47:52 +0100 Subject: [PATCH 2/8] Fix Python linter issues --- python/pyspark/sql/tests/test_types.py | 218 +++++++++++++++++-------- 1 file changed, 146 insertions(+), 72 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index d19d604be97e8..4ec4282c4442c 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2454,7 +2454,12 @@ def test_variant_type(self): def test_geospatial_encoding(self): df = self.spark.createDataFrame( - [(bytes.fromhex("0101000000000000000000F03F0000000000000040"), 4326,)], + [ + ( + bytes.fromhex("0101000000000000000000F03F0000000000000040"), + 4326, + ) + ], ["wkb", "srid"], ) row = df.select( @@ -2464,27 +2469,41 @@ def test_geospatial_encoding(self): self.assertEqual(type(row["geom"]), Geometry) self.assertEqual(type(row["geog"]), Geography) - self.assertEqual(row["geom"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")) + self.assertEqual( + row["geom"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040") + ) self.assertEqual(row["geom"].getSrid(), 0) - self.assertEqual(row["geog"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")) + self.assertEqual( + row["geog"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040") + ) self.assertEqual(row["geog"].getSrid(), 4326) def test_geospatial_create_dataframe_rdd(self): - schema = StructType([ - StructField("id", IntegerType(), True), - StructField("geom", GeometryType(0), True), - StructField("geom4326", GeometryType(4326), True), - StructField("geog", GeographyType(4326), True) - ]) + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geog", GeographyType(4326), True), + ] + ) geospatial_data = [ - (1, - Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), - (2, - Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + ( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + ( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), ] rdd_data = self.sc.parallelize(geospatial_data) df = self.spark.createDataFrame(rdd_data, schema) @@ -2494,7 +2513,7 @@ def test_geospatial_create_dataframe_rdd(self): F.st_asbinary(df.geom4326).alias("geom4326_wkb"), F.st_srid(df.geom4326).alias("geom4326_srid"), F.st_asbinary(df.geog).alias("geog_wkb"), - F.st_srid(df.geog).alias("geog_srid") + F.st_srid(df.geog).alias("geog_srid"), ).collect() point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") @@ -2515,53 +2534,92 @@ def test_geospatial_create_dataframe_rdd(self): F.col("_1").alias("id"), F.col("_2").alias("geom"), F.col("_3").alias("geom4326"), - F.col("_4").alias("geog") + F.col("_4").alias("geog"), ) self.assertEqual(df.collect(), schema_df.collect()) - def test_geospatial_create_dataframe(self): # Positive Test: Creating DataFrame from a list of tuples with explicit schema geospatial_data = [ - (1, - Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), - (2, - Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + ( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + ( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), ] named_geospatial_data = [ - {"id": 1, "geom": Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), - "geom4326": Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), - "geog": Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)}, - {"id": 2, "geom": Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), - "geom4326": Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), - "geog": Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)} + { + "id": 1, + "geom": Geometry.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 0 + ), + "geom4326": Geometry.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + "geog": Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + }, + { + "id": 2, + "geom": Geometry.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 0 + ), + "geom4326": Geometry.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + "geog": Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + }, ] GeospatialRow = Row("id", "geom", "geom4326", "geog") spark_row_data = [ - GeospatialRow(1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326)), - GeospatialRow(2, Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), - Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), - Geography.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326)) + GeospatialRow( + 1, + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000031400000000000001c40"), 4326 + ), + ), + GeospatialRow( + 2, + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0), + Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326), + Geography.fromWKB( + bytes.fromhex("010100000000000000000014400000000000001440"), 4326 + ), + ), ] - schema = StructType([ - StructField("id", IntegerType(), True), - StructField("geom", GeometryType(0), True), - StructField("geom4326", GeometryType(4326), True), - StructField("geog", GeographyType(4326), True) - ]) + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("geom", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geog", GeographyType(4326), True), + ] + ) # Negative Test: Schema mismatch - mismatched_schema = StructType([ - StructField("id", IntegerType(), True), # Should be GeometryType - StructField("geom", GeometryType(4326), True), # Should be GeometryType - StructField("geom4326", GeometryType(4326), True), # Should be GeometryType - StructField("geog", GeographyType(4326), True) # Should be GeographyType - ]) + mismatched_schema = StructType( + [ + StructField("id", IntegerType(), True), # Should be GeometryType + StructField("geom", GeometryType(4326), True), # Should be GeometryType + StructField("geom4326", GeometryType(4326), True), # Should be GeometryType + StructField("geog", GeographyType(4326), True), # Should be GeographyType + ] + ) # Explicitly validate single test case # rest will be compared with this one. @@ -2572,7 +2630,7 @@ def test_geospatial_create_dataframe(self): F.st_asbinary(df.geom4326).alias("geom4326_wkb"), F.st_srid(df.geom4326).alias("geom4326_srid"), F.st_asbinary(df.geog).alias("geog_wkb"), - F.st_srid(df.geog).alias("geog_srid") + F.st_srid(df.geog).alias("geog_srid"), ).collect() point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40") @@ -2591,16 +2649,17 @@ def test_geospatial_create_dataframe(self): self.assertEqual(rows[1]["geog_srid"], 4326) # Just the data set without parameters. - self.assertEqual(self.spark.createDataFrame(named_geospatial_data).select("id", "geom", "geom4326", "geog").collect(), df.collect()) + self.assertEqual( + self.spark.createDataFrame(named_geospatial_data) + .select("id", "geom", "geom4326", "geog") + .collect(), + df.collect(), + ) self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(), df.collect()) self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(), df.collect()) # Define DataFrame creation methods - datasets = [ - named_geospatial_data, - geospatial_data, - spark_row_data - ] + datasets = [named_geospatial_data, geospatial_data, spark_row_data] schemas = [ schema, "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog GEOGRAPHY(4326)", @@ -2608,7 +2667,9 @@ def test_geospatial_create_dataframe(self): ] for dataset_to_check, schema_to_check in zip(datasets, schemas): - df_to_check = self.spark.createDataFrame(dataset_to_check, schema_to_check).select("id", "geom", "geom4326", "geog") + df_to_check = self.spark.createDataFrame(dataset_to_check, schema_to_check).select( + "id", "geom", "geom4326", "geog" + ) self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame creation with schema") # Negative Test: Schema mismatch @@ -2633,11 +2694,13 @@ def test_geospatial_schema_inferrence(self): ] # Create DataFrame with mixed data types df = self.spark.createDataFrame(geometry_dataset, ["geom0", "geom4326", "geom_any"]) - expected_schema = StructType([ - StructField("geom0", GeometryType(0), True), - StructField("geom4326", GeometryType(4326), True), - StructField("geom_any", GeometryType("ANY"), True) - ]) + expected_schema = StructType( + [ + StructField("geom0", GeometryType(0), True), + StructField("geom4326", GeometryType(4326), True), + StructField("geom_any", GeometryType("ANY"), True), + ] + ) self.assertEqual(df.schema, expected_schema) rows = df.select( @@ -2666,7 +2729,7 @@ def test_geospatial_schema_inferrence(self): def test_geospatial_mixed_check_srid_validity(self): geometry_mixed_invalid_data = [ (1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0)), - (2, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 1)) + (2, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 1)), ] with self.assertRaises(IllegalArgumentException) as pe: @@ -2678,7 +2741,9 @@ def test_geospatial_mixed_check_srid_validity(self): ) with self.assertRaises(IllegalArgumentException) as pe: - self.spark.createDataFrame(geometry_mixed_invalid_data, f"id INT, geom GEOMETRY(ANY)").collect() + self.spark.createDataFrame( + geometry_mixed_invalid_data, f"id INT, geom GEOMETRY(ANY)" + ).collect() self.check_error( exception=pe.exception, errorClass="ST_INVALID_SRID_VALUE", @@ -2687,14 +2752,23 @@ def test_geospatial_mixed_check_srid_validity(self): def test_geospatial_result_encoding(self): point_bytes = bytes.fromhex("010100000000000000000031400000000000001c40") - df = self.spark.sql(""" + df = self.spark.sql( + """ SELECT ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 0) AS geom, ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 4326) AS geom4326, - ST_GeogFromWKB(X'010100000000000000000031400000000000001c40') AS geog""") + ST_GeogFromWKB(X'010100000000000000000031400000000000001c40') AS geog""" + ) GeospatialRow = Row("geom", "geom4326", "geog") - self.assertEqual(df.collect(), [GeospatialRow(Geometry.fromWKB(point_bytes, 0), - Geometry.fromWKB(point_bytes, 4326), - Geography.fromWKB(point_bytes, 4326))]) + self.assertEqual( + df.collect(), + [ + GeospatialRow( + Geometry.fromWKB(point_bytes, 0), + Geometry.fromWKB(point_bytes, 4326), + Geography.fromWKB(point_bytes, 4326), + ) + ], + ) def test_to_ddl(self): schema = StructType().add("a", NullType()).add("b", BooleanType()).add("c", BinaryType()) From 7d2070d2c6bb68ae9c5f85b02ad61751f6b58231 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 15:51:17 +0100 Subject: [PATCH 3/8] Fix flake8 issue --- python/pyspark/sql/tests/test_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 4ec4282c4442c..4248a37197af6 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2742,7 +2742,7 @@ def test_geospatial_mixed_check_srid_validity(self): with self.assertRaises(IllegalArgumentException) as pe: self.spark.createDataFrame( - geometry_mixed_invalid_data, f"id INT, geom GEOMETRY(ANY)" + geometry_mixed_invalid_data, "id INT, geom GEOMETRY(ANY)" ).collect() self.check_error( exception=pe.exception, From 35af36a6942f0749f12ce7a00d7a41de18ecd236 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 4 Nov 2025 20:41:39 +0100 Subject: [PATCH 4/8] Fix mypy issues --- python/pyspark/sql/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 52b9b7109d005..89bf5d7797b39 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2110,7 +2110,7 @@ def getBytes(self) -> bytes: """ return self.wkb - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Geography): # Don't attempt to compare against unrelated types. return NotImplemented @@ -2181,7 +2181,7 @@ def getBytes(self) -> bytes: """ return self.wkb - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Geometry): # Don't attempt to compare against unrelated types. return NotImplemented From 7adc1e30b16c63b5d570cdaca66f3ac0b348a17c Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 5 Nov 2025 22:42:55 +0100 Subject: [PATCH 5/8] Update Pandas --- python/pyspark/sql/pandas/types.py | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 327e3941d9386..581c9bdd14e98 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -50,6 +50,10 @@ UserDefinedType, VariantType, VariantVal, + GeometryType, + Geometry, + GeographyType, + Geography, _create_row, ) from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError @@ -202,6 +206,18 @@ def to_arrow_type( pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}), ] arrow_type = pa.struct(fields) + elif type(dt) == GeometryType: + fields = [ + pa.field("srid", pa.int32(), nullable=False), + pa.field("wkb", pa.binary(), nullable=False, metadata={b"geometry": b"true", b"srid": str(dt.srid)}), + ] + arrow_type = pa.struct(fields) + elif type(dt) == GeographyType: + fields = [ + pa.field("srid", pa.int32(), nullable=False), + pa.field("wkb", pa.binary(), nullable=False, metadata={b"geography": b"true", b"srid": str(dt.srid)}), + ] + arrow_type = pa.struct(fields) else: raise PySparkTypeError( errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", @@ -272,6 +288,38 @@ def is_variant(at: "pa.DataType") -> bool: ) and any(field.name == "value" for field in at) +def is_geometry(at: "pa.DataType") -> bool: + """Check if a PyArrow struct data type represents a geometry""" + import pyarrow.types as types + + assert types.is_struct(at) + + return any( + ( + field.name == "wkb" + and b"geometry" in field.metadata + and field.metadata[b"geometry"] == b"true" + ) + for field in at + ) and any(field.name == "srid" for field in at) + + +def is_geography(at: "pa.DataType") -> bool: + """Check if a PyArrow struct data type represents a geography""" + import pyarrow.types as types + + assert types.is_struct(at) + + return any( + ( + field.name == "wkb" + and b"geography" in field.metadata + and field.metadata[b"geography"] == b"true" + ) + for field in at + ) and any(field.name == "srid" for field in at) + + def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType: """Convert pyarrow type to Spark data type.""" import pyarrow.types as types @@ -337,6 +385,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da elif types.is_struct(at): if is_variant(at): return VariantType() + elif is_geometry(at): + srid = int(at.field("wkb").metadata.get(b"srid")) + if srid == GeometryType.MIXED_SRID: + return GeometryType("ANY") + else: + return GeometryType(srid) + elif is_geography(at): + srid = int(at.field("wkb").metadata.get(b"srid")) + if srid == GeographyType.MIXED_SRID: + return GeographyType("ANY") + else: + return GeographyType(srid) return StructType( [ StructField( From e7823c6c78bccee05b05eb5554e38d796f8bc3b6 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 6 Nov 2025 03:04:57 +0100 Subject: [PATCH 6/8] Fix Python format issues --- python/pyspark/sql/pandas/types.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 581c9bdd14e98..e600e7426d6b8 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -209,13 +209,23 @@ def to_arrow_type( elif type(dt) == GeometryType: fields = [ pa.field("srid", pa.int32(), nullable=False), - pa.field("wkb", pa.binary(), nullable=False, metadata={b"geometry": b"true", b"srid": str(dt.srid)}), + pa.field( + "wkb", + pa.binary(), + nullable=False, + metadata={b"geometry": b"true", b"srid": str(dt.srid)}, + ), ] arrow_type = pa.struct(fields) elif type(dt) == GeographyType: fields = [ pa.field("srid", pa.int32(), nullable=False), - pa.field("wkb", pa.binary(), nullable=False, metadata={b"geography": b"true", b"srid": str(dt.srid)}), + pa.field( + "wkb", + pa.binary(), + nullable=False, + metadata={b"geography": b"true", b"srid": str(dt.srid)}, + ), ] arrow_type = pa.struct(fields) else: From 9b0381c5e0ac16051662916649b8d29e54bc9ed4 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 6 Nov 2025 11:30:44 +0100 Subject: [PATCH 7/8] Fixes --- python/pyspark/sql/pandas/types.py | 50 ++++++++++++++++++++++++++ python/pyspark/sql/tests/test_types.py | 15 ++++---- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index e600e7426d6b8..d8a45daa77e89 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -1168,6 +1168,40 @@ def convert_variant(value: Any) -> Any: return convert_variant + elif isinstance(dt, GeographyType): + + def convert_geography(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geography.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY") + + return convert_geography + + elif isinstance(dt, GeometryType): + + def convert_geometry(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["wkb", "srid"]) + and isinstance(value["wkb"], bytes) + and isinstance(value["srid"], int) + ): + return Geometry.fromWKB(value["wkb"], value["srid"]) + else: + raise PySparkValueError(errorClass="MALFORMED_GEOMETRY") + + return convert_geometry + else: return None @@ -1430,6 +1464,22 @@ def convert_variant(variant: Any) -> Any: return convert_variant + elif isinstance(dt, GeographyType): + + def convert_geography(value: Any) -> Any: + assert isinstance(value, Geography) + return {"srid": value.srid, "wkb": value.wkb} + + return convert_geography + + elif isinstance(dt, GeometryType): + + def convert_geometry(value: Any) -> Any: + assert isinstance(value, Geometry) + return {"srid": value.srid, "wkb": value.wkb} + + return convert_geometry + return None conv = _converter(data_type) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 4248a37197af6..6b82b4ae9a6cb 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2751,20 +2751,17 @@ def test_geospatial_mixed_check_srid_validity(self): ) def test_geospatial_result_encoding(self): - point_bytes = bytes.fromhex("010100000000000000000031400000000000001c40") - df = self.spark.sql( - """ - SELECT ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 0) AS geom, - ST_GeomFromWKB(X'010100000000000000000031400000000000001c40', 4326) AS geom4326, - ST_GeogFromWKB(X'010100000000000000000031400000000000001c40') AS geog""" - ) - GeospatialRow = Row("geom", "geom4326", "geog") + point_wkb = "010100000000000000000031400000000000001c40" + point_bytes = bytes.fromhex(point_wkb) + df = self.spark.sql(f""" + SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom, + ST_GeogFromWKB(X'{point_wkb}') AS geog""") + GeospatialRow = Row("geom", "geog") self.assertEqual( df.collect(), [ GeospatialRow( Geometry.fromWKB(point_bytes, 0), - Geometry.fromWKB(point_bytes, 4326), Geography.fromWKB(point_bytes, 4326), ) ], From 8dd8baf38bf57dc4cba24d2d91c0ed81f655127c Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Fri, 7 Nov 2025 07:10:05 +0100 Subject: [PATCH 8/8] Fix Python linter --- python/pyspark/sql/tests/test_types.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 6b82b4ae9a6cb..4ff2ab3e5cd73 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2753,9 +2753,11 @@ def test_geospatial_mixed_check_srid_validity(self): def test_geospatial_result_encoding(self): point_wkb = "010100000000000000000031400000000000001c40" point_bytes = bytes.fromhex(point_wkb) - df = self.spark.sql(f""" + df = self.spark.sql( + f""" SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom, - ST_GeogFromWKB(X'{point_wkb}') AS geog""") + ST_GeogFromWKB(X'{point_wkb}') AS geog""" + ) GeospatialRow = Row("geom", "geog") self.assertEqual( df.collect(),