Skip to content
Prev Previous commit
Next Next commit
rmv f from _create_arrow_py_udf
  • Loading branch information
xinrong-meng committed Apr 20, 2023
commit f5aef182ef108f22f138a5c19690c4b10c98551d
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _create_py_udf(
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(f, regular_udf)
return _create_arrow_py_udf(regular_udf)
else:
return regular_udf

Expand Down
5 changes: 3 additions & 2 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,20 @@ def _create_py_udf(
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(f, regular_udf)
return _create_arrow_py_udf(regular_udf)
else:
return regular_udf


def _create_arrow_py_udf(f, regular_udf): # type: ignore
def _create_arrow_py_udf(regular_udf): # type: ignore
"""Create an Arrow-optimized Python UDF out of a regular Python UDF."""
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf

f = regular_udf.func
return_type = regular_udf.returnType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems that the regular_udf is only used to pass the returnType and evalType ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And regular_udf.func based on the updated code.


# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
Expand Down