diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 9351dbdfd25b4..be8ffacfa3d7b 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -415,8 +415,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         None,
@@ -459,8 +459,8 @@ def _validate_pandas_udf(f, evalType) -> int:
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 5b2ee83beec0b..5fe711f742ce6 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -654,10 +654,10 @@ def __transformWithState(
         elif usePandas and initialState is not None:
             functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
         elif not usePandas and initialState is None:
-            functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF
+            functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
         else:
             # not usePandas and initialState is not None
-            functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF
+            functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
 
         if initialState is None:
             initial_state_java_obj = None
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 605f4f070b5a5..7e07d95538e4a 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -641,8 +641,8 @@ class PythonEvalType:
     SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: "PandasGroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
         212
     )
-    SQL_TRANSFORM_WITH_STATE_UDF: "GroupedMapUDFTransformWithStateType" = 213
-    SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF: "GroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
+    SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: "GroupedMapUDFTransformWithStateType" = 213
+    SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: "GroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
         214
     )
     SQL_TABLE_UDF: "SQLTableUDFType" = 300
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 61c006b1a9cf2..5f4408851c671 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -969,9 +969,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
         return args_offsets, wrap_grouped_transform_with_state_pandas_init_state_udf(
             func, return_type, runner_conf
         )
-    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF:
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
         return args_offsets, wrap_grouped_transform_with_state_udf(func, return_type, runner_conf)
-    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF:
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
         return args_offsets, wrap_grouped_transform_with_state_init_state_udf(
             func, return_type, runner_conf
         )
@@ -1572,8 +1572,8 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
-        PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
     ):
         # Load conf used for pandas_udf evaluation
         num_conf = read_int(infile)
@@ -1588,8 +1588,8 @@ def read_udfs(pickleSer, infile, eval_type):
     elif (
         eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
        or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
-        or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF
-        or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF
+        or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
+        or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
     ):
         state_server_port = read_int(infile)
         if state_server_port == -1:
@@ -1641,14 +1641,14 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = TransformWithStateInPandasInitStateSerializer(
                 timezone, safecheck, _assign_cols_by_name, arrow_max_records_per_batch
             )
-        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF:
+        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
             arrow_max_records_per_batch = runner_conf.get(
                 "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
             )
             arrow_max_records_per_batch = int(arrow_max_records_per_batch)
 
             ser = TransformWithStateInPySparkRowSerializer(arrow_max_records_per_batch)
-        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF:
+        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
             arrow_max_records_per_batch = runner_conf.get(
                 "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
             )
@@ -1889,7 +1889,7 @@ def values_gen():
                 # mode == PROCESS_TIMER or mode == COMPLETE
                 return f(stateful_processor_api_client, mode, None, iter([]))
 
-    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF:
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
         # We assume there is only one UDF here because grouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
@@ -1916,7 +1916,7 @@ def mapper(a):
                 # mode == PROCESS_TIMER or mode == COMPLETE
                 return f(stateful_processor_api_client, mode, None, iter([]))
 
-    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF:
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
         # We assume there is only one UDF here because grouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
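For context, here is a minimal standalone sketch (not part of this patch; the helper `pick_eval_type` is hypothetical) of the dispatch that `__transformWithState` in `python/pyspark/sql/pandas/group_ops.py` performs after the rename: the two user-facing knobs, pandas vs. PySpark Row API and the presence of an initial state, map onto the four eval types, with the Row-based pair now carrying a `PYTHON_ROW` infix. The numeric values 213 and 214 are unchanged per `python/pyspark/util.py`, so the rename is source-level only.

```python
# Hypothetical sketch of the eval-type dispatch in __transformWithState
# (python/pyspark/sql/pandas/group_ops.py) after this patch.
from pyspark.util import PythonEvalType


def pick_eval_type(use_pandas: bool, has_initial_state: bool) -> int:
    if use_pandas and not has_initial_state:
        return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
    if use_pandas and has_initial_state:
        # 212 per python/pyspark/util.py
        return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
    if not use_pandas and not has_initial_state:
        # renamed from SQL_TRANSFORM_WITH_STATE_UDF; value stays 213
        return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
    # renamed from SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF; value stays 214
    return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF


# The Row-based pair keeps its wire values, so worker-side dispatch in
# python/pyspark/worker.py is unaffected at the protocol level.
assert pick_eval_type(use_pandas=False, has_initial_state=False) == 213
assert pick_eval_type(use_pandas=False, has_initial_state=True) == 214
```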