diff --git a/assembly/pom.xml b/assembly/pom.xml
index 4f80719afd4c..69952a7ccfa0 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -166,6 +166,12 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-protobuf_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c5be1957a7dc..4e432e3eba98 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -271,27 +271,27 @@ def __hash__(self):
     ],
 )
-connect = Module(
-    name="connect",
-    dependencies=[hive, avro],
+protobuf = Module(
+    name="protobuf",
+    dependencies=[sql],
     source_file_regexes=[
-        "connector/connect",
+        "connector/protobuf",
     ],
-    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "connect/test",
-        "connect-client-jvm/test",
+        "protobuf/test",
     ],
 )
-protobuf = Module(
-    name="protobuf",
-    dependencies=[sql],
+connect = Module(
+    name="connect",
+    dependencies=[hive, avro, protobuf],
     source_file_regexes=[
-        "connector/protobuf",
+        "connector/connect",
     ],
+    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "protobuf/test",
+        "connect/test",
+        "connect-client-jvm/test",
     ],
 )
@@ -832,6 +832,7 @@ def __hash__(self):
"pyspark.sql.connect.dataframe",
"pyspark.sql.connect.functions",
"pyspark.sql.connect.avro.functions",
+ "pyspark.sql.connect.protobuf.functions",
"pyspark.sql.connect.streaming.readwriter",
"pyspark.sql.connect.streaming.query",
# sql unittests
diff --git a/python/pyspark/sql/connect/protobuf/__init__.py b/python/pyspark/sql/connect/protobuf/__init__.py
new file mode 100644
index 000000000000..dc81e9f515ee
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Protobuf Functions"""
diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py
new file mode 100644
index 000000000000..56119f4bc4eb
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/functions.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin protobuf functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.protobuf import functions as PyProtobufFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "from_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("from_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+from_protobuf.__doc__ = PyProtobufFunctions.from_protobuf.__doc__
+
+
+def to_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "to_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("to_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+to_protobuf.__doc__ = PyProtobufFunctions.to_protobuf.__doc__
+
+
+def _read_descriptor_set_file(filePath: str) -> bytes:
+    with open(filePath, "rb") as f:
+        return f.read()
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
+    if protobuf_jar is None:
+        print(
+            "Skipping all Protobuf Python tests as the optional Protobuf project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt package' or "
+            "'build/mvn package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % protobuf_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.protobuf.functions
+
+    globs = pyspark.sql.connect.protobuf.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.protobuf.functions tests")
+        .remote("local[2]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.protobuf.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
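
A minimal usage sketch of the new Connect protobuf functions (the message name
"SimpleMessage", the descriptor path, and the payload are illustrative
assumptions, not part of this change):

    from pyspark.sql.connect.protobuf.functions import from_protobuf, to_protobuf

    # Assumes `spark` is an active Spark Connect session and that the descriptor
    # set file was produced beforehand with `protoc --descriptor_set_out=...`.
    df = spark.createDataFrame([(b"\x08\x01",)], ["value"])
    parsed = df.select(
        from_protobuf("value", "SimpleMessage", descFilePath="/tmp/simple.desc").alias("msg")
    )
    roundtrip = parsed.select(
        to_protobuf("msg", "SimpleMessage", descFilePath="/tmp/simple.desc").alias("value")
    )
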
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 42165938eb73..acb1a17efbd6 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -25,13 +25,14 @@
from py4j.java_gateway import JVMView
from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_protobuf_functions
from pyspark.util import _print_missing_jar
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
+@try_remote_protobuf_functions
 def from_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -58,6 +59,7 @@ def from_protobuf(
     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.
 
     Parameters
     ----------
@@ -161,6 +163,7 @@ def from_protobuf(
     return Column(jc)
+@try_remote_protobuf_functions
 def to_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -187,6 +190,7 @@ def to_protobuf(
     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.
 
     Parameters
     ----------
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cb262a14cbe2..45df4433916c 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -211,6 +211,22 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)
+def try_remote_protobuf_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.protobuf import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
def try_remote_window(f: FuncT) -> FuncT:
"""Mark API supported from Spark Connect."""