in Connect
xinrong-meng committed Apr 17, 2023
commit d702b67e651ca089280b1d1074b321619049fc45
12 changes: 9 additions & 3 deletions python/pyspark/sql/connect/functions.py
@@ -50,7 +50,7 @@
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
 )
-from pyspark.sql.connect.udf import _create_udf
+from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
 
@@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
 def udf(
     f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
     returnType: "DataTypeOrString" = StringType(),
+    useArrow: Optional[bool] = None,
 ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
     from pyspark.rdd import PythonEvalType
 
@@ -2469,10 +2470,15 @@ def udf(
         # for decorator use it as a returnType
         return_type = f or returnType
         return functools.partial(
-            _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+            _create_py_udf,
+            returnType=return_type,
+            evalType=PythonEvalType.SQL_BATCHED_UDF,
+            useArrow=useArrow,
         )
     else:
-        return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+        return _create_py_udf(
+            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
+        )
 
 
 udf.__doc__ = pysparkfuncs.udf.__doc__
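For reference, the new `useArrow` flag matches the vanilla PySpark `udf` API: `True`/`False` force Arrow-optimized execution on or off for that one UDF, while the default `None` defers to the `spark.sql.execution.pythonUDF.arrow.enabled` configuration. A minimal usage sketch, where the remote address, DataFrame, and column name are illustrative and `spark` is assumed to be an active Spark Connect session:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Assumes a Spark Connect server is reachable at this illustrative address.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

@udf(returnType="int", useArrow=True)  # force the Arrow-optimized path for this UDF
def add_one(v):
    return v + 1

# The Arrow path transfers batches to/from the Python worker instead of
# pickling values row by row.
spark.range(3).toDF("v").select(add_one("v")).show()
```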
38 changes: 37 additions & 1 deletion python/pyspark/sql/connect/udf.py
@@ -23,6 +23,7 @@
 
 import sys
 import functools
+from inspect import getfullargspec
 from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
@@ -33,7 +34,7 @@
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
-from pyspark.sql.types import DataType, StringType
+from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 
 
@@ -47,6 +48,41 @@
     from pyspark.sql.types import StringType
 
 
+def _create_py_udf(
+    f: Callable[..., Any],
+    returnType: "DataTypeOrString",
+    evalType: int,
+    useArrow: Optional[bool] = None,
+) -> "UserDefinedFunctionLike":
+    from pyspark.sql.udf import _create_arrow_py_udf
+    from pyspark.sql.connect.session import _active_spark_session
+
+    if _active_spark_session is None:
+        is_arrow_enabled = False
+    else:
+        is_arrow_enabled = (
+            _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
+            if useArrow is None
+            else useArrow
+        )
+
+    regular_udf = _create_udf(f, returnType, evalType)
Inline review comment from the commit author (xinrong-meng):

> There is duplicated code in _create_py_udf between the Spark Connect Python Client and vanilla PySpark, except for fetching the active SparkSession.
> However, for a clear code path separation and abstraction, I decided not to refactor it for now.

+    return_type = regular_udf.returnType
+    try:
+        is_func_with_args = len(getfullargspec(f).args) > 0
+    except TypeError:
+        is_func_with_args = False
+    is_output_atomic_type = (
+        not isinstance(return_type, StructType)
+        and not isinstance(return_type, MapType)
+        and not isinstance(return_type, ArrayType)
+    )
+    if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
+        return _create_arrow_py_udf(f, regular_udf)
+    else:
+        return regular_udf
+
+
 def _create_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
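Taken together, the new helper only routes a UDF through `_create_arrow_py_udf` when three conditions hold: Arrow is enabled (the per-UDF flag first, the session conf second), the function takes at least one argument, and the return type is atomic. A minimal, self-contained sketch of that gating logic; the helper name `arrow_applicable` and the assertions are illustrative, not part of the patch:

```python
from inspect import getfullargspec
from typing import Any, Callable, Optional

from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType


def arrow_applicable(
    f: Callable[..., Any],
    return_type: DataType,
    conf_value: Optional[str],
    use_arrow: Optional[bool],
) -> bool:
    """Mirrors the decision in _create_py_udf: use the Arrow path only when
    it is enabled, the UDF takes at least one argument, and the return type
    is atomic (not struct, map, or array)."""
    # The per-UDF useArrow flag overrides the session-wide conf value.
    is_arrow_enabled = (conf_value == "true") if use_arrow is None else use_arrow

    # Builtins and other non-Python callables raise TypeError in getfullargspec.
    try:
        is_func_with_args = len(getfullargspec(f).args) > 0
    except TypeError:
        is_func_with_args = False

    is_output_atomic_type = not isinstance(return_type, (StructType, MapType, ArrayType))
    return is_arrow_enabled and is_output_atomic_type and is_func_with_args


assert arrow_applicable(lambda x: x, StringType(), "true", None)                  # eligible
assert not arrow_applicable(lambda x: x, ArrayType(StringType()), "true", None)   # non-atomic output
assert not arrow_applicable(lambda: 1, StringType(), "false", True)               # zero-arg function
```

One subtlety visible in the diff: when no Spark Connect session is active, `_active_spark_session` is `None` and `is_arrow_enabled` is forced to `False`, so even an explicit `useArrow=True` falls back to the regular pickled UDF.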