[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
Changes to the pandas group operations module, adding the transformWithStateInPandas API to grouped data:
@@ -15,14 +15,17 @@
 # limitations under the License.
 #
 import sys
-from typing import List, Union, TYPE_CHECKING, cast
+from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
 import warnings

 from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions.builtin import udf
 from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.streaming.state_api_client import StateApiClient, StatefulProcessorHandleState
+from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.types import StructType, _parse_datatype_string

 if TYPE_CHECKING:
@@ -33,6 +36,7 @@
         PandasCogroupedMapFunction,
         ArrowGroupedMapFunction,
         ArrowCogroupedMapFunction,
+        DataFrameLike as PandasDataFrameLike
     )
     from pyspark.sql.group import GroupedData
@@ -358,6 +362,120 @@ def applyInPandasWithState(
        )
        return DataFrame(jdf, self.session)

    def transformWithStateInPandas(self,
                                   stateful_processor: StatefulProcessor,
                                   outputStructType: Union[StructType, str],
                                   outputMode: str,
                                   timeMode: str) -> DataFrame:
| """ | ||
| Invokes methods defined in the stateful processor used in arbitrary state API v2. | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| We allow the user to act on per-group set of input rows along with keyed state and the | ||
| user can choose to output/return 0 or more rows. | ||
|
|
||
| For a streaming dataframe, we will repeatedly invoke the interface methods for new rows | ||
| in each trigger and the user's state/state variables will be stored persistently across | ||
| invocations. | ||
|
|
||
| The `stateful_processor` should be a Python class that implements the interface defined in | ||
| pyspark.sql.streaming.stateful_processor.StatefulProcessor. | ||
|
|
||
| The `outputStructType` should be a :class:`StructType` describing the schema of all | ||
| elements in the returned value, `pandas.DataFrame`. The column labels of all elements in | ||
| returned `pandas.DataFrame` must either match the field names in the defined schema if | ||
| specified as strings, or match the field data types by position if not strings, | ||
| e.g. integer indices. | ||
|
|
||
| The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The | ||
| number of `pandas.DataFrame` in both the input and output can also be arbitrary. | ||
bogao007 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| stateful_processor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` | ||
| Instance of StatefulProcessor whose functions will be invoked by the operator. | ||
| outputStructType : :class:`pyspark.sql.types.DataType` or str | ||
| The type of the output records. The value can be either a | ||
| :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. | ||
| outputMode : str | ||
| The output mode of the stateful processor. | ||
| timeMode : str | ||
| The time mode semantics of the stateful processor for timers and TTL. | ||
|
|
||
HyukjinKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Examples | ||
| -------- | ||
| >>> import pandas as pd | ||
| >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle | ||
| >>> from pyspark.sql.types import StructType, StructField, LongType, StringType | ||
| >>> from typing import Iterator | ||
| >>> output_schema = StructType([ | ||
| ... StructField("value", LongType(), True) | ||
| ... ]) | ||
| >>> state_schema = StructType([ | ||
bogao007 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ... StructField("value", StringType(), True) | ||
| ... ]) | ||
| >>> class SimpleStatefulProcessor(StatefulProcessor): | ||
| ... def init(self, handle: StatefulProcessorHandle) -> None: | ||
| ... self.value_state = handle.getValueState("testValueState", state_schema) | ||
| ... def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: | ||
| ... self.value_state.update("test_value") | ||
| ... exists = self.value_state.exists() | ||
| ... value = self.value_state.get() | ||
| ... self.value_state.clear() | ||
| ... return rows | ||
| ... def close(self) -> None: | ||
| ... pass | ||
|
||
| ... | ||
| >>> df.groupBy("value").transformWithStateInPandas(stateful_processor = | ||
| ... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update", | ||
| ... timeMode="None") # doctest: +SKIP | ||
|
|
||
| Notes | ||
| ----- | ||
| This function requires a full shuffle. | ||
|
|
||
| This API is experimental. | ||
| """ | ||
|
|
||
        from pyspark.sql import GroupedData
        from pyspark.sql.functions import pandas_udf

        assert isinstance(self, GroupedData)

        def transformWithStateUDF(state_api_client: StateApiClient, key: Any,
                                  inputRows: Iterator["PandasDataFrameLike"]
                                  ) -> Iterator["PandasDataFrameLike"]:
            handle = StatefulProcessorHandle(state_api_client)

            if state_api_client.handle_state == StatefulProcessorHandleState.CREATED:
                stateful_processor.init(handle)
                state_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

            state_api_client.set_implicit_key(str(key[0]))
            result = stateful_processor.handleInputRows(key, inputRows)
            state_api_client.remove_implicit_key()

            return result

        if isinstance(outputStructType, str):
            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

        udf = pandas_udf(
            transformWithStateUDF,  # type: ignore[call-overload]
            returnType=outputStructType,
            functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE,
        )
        df = self._df
        udf_column = udf(*[df[col] for col in df.columns])

        jdf = self._jgd.transformWithStateInPandas(
            udf_column._jc.expr(),
            self.session._jsparkSession.parseDataType(outputStructType.json()),
            outputMode,
            timeMode,
        )
        return DataFrame(jdf, self.session)

    def applyInArrow(
        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
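To make the end-to-end flow of the new API concrete, here is a minimal usage sketch of transformWithStateInPandas in a streaming query. It is illustrative only: the socket source, the column names, the RunningCountProcessor class, and the assumption that ValueState.get() returns the value previously passed to update() (as in the docstring example above) are not part of this diff.

# Minimal, hedged usage sketch; RunningCountProcessor and the socket source are assumptions.
import pandas as pd
from typing import Iterator
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, LongType

output_schema = StructType([
    StructField("key", StringType(), True),
    StructField("count", LongType(), True),
])
count_schema = StructType([StructField("count", StringType(), True)])

class RunningCountProcessor(StatefulProcessor):
    """Keeps a running count per grouping key in a ValueState variable."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        self.count_state = handle.getValueState("countState", count_schema)

    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
        # Assumption: get() returns the string previously stored with update().
        previous = self.count_state.get() if self.count_state.exists() else "0"
        count = int(previous) + sum(len(batch) for batch in rows)
        self.count_state.update(str(count))
        yield pd.DataFrame({"key": [key[0]], "count": [count]})

    def close(self) -> None:
        pass

spark = SparkSession.builder.getOrCreate()
lines = spark.readStream.format("socket").option("host", "localhost").option("port", 9999).load()
query = (
    lines.groupBy("value")
    .transformWithStateInPandas(
        stateful_processor=RunningCountProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.outputMode("update")
    .format("console")
    .start()
)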
Changes to the pandas serializers module (serializers for PyArrow and pandas conversions):
@@ -19,9 +19,14 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """

+from enum import Enum
+from itertools import groupby
+import os
+import socket
+from typing import Any
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
 from pyspark.loose_version import LooseVersion
-from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
+from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer, write_with_length
 from pyspark.sql.pandas.types import (
     from_arrow_type,
     to_arrow_type,
@@ -1116,3 +1121,71 @@ def init_stream_yield_batches(batches):
        batches_to_write = init_stream_yield_batches(serialize_batches())

        return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)

class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
    """
    Serializer used by the Python worker to evaluate the UDF for transformWithStateInPandas.

    Parameters
    ----------
    timezone : str
        A timezone to respect when handling timestamp values
    safecheck : bool
        If True, conversion from Arrow to Pandas checks for overflow/truncation
    assign_cols_by_name : bool
        If True, then Pandas DataFrames will get columns by name
    arrow_max_records_per_batch : int
        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
    """

    def __init__(
            self,
            timezone,
            safecheck,
            assign_cols_by_name,
            arrow_max_records_per_batch):
        super(
            TransformWithStateInPandasSerializer,
            self
        ).__init__(timezone, safecheck, assign_cols_by_name)
        self.arrow_max_records_per_batch = arrow_max_records_per_batch
        self.key_offsets = None

    # Nothing special here, we need to create the handle and read
    # data in groups.
    def load_stream(self, stream):
        """
        Read ArrowRecordBatches from stream, deserialize them to populate a list of pairs
        (grouping key, data chunk), and convert the data into a list of pandas.Series.

        Please refer to the doc of the inner function `generate_data_batches` for more details
        on how this function works overall.

        In addition, this function further groups the return of `generate_data_batches` by the
        grouping key and produces an iterator of data chunks for each group, so that the caller
        can lazily materialize the data chunks.
        """
        import pyarrow as pa

        def generate_data_batches(batches):
            for batch in batches:
                data_pandas = [
                    self.arrow_to_pandas(c)
                    for c in pa.Table.from_batches([batch]).itercolumns()
                ]
                key_series = [data_pandas[o] for o in self.key_offsets]
                batch_key = tuple(s[0] for s in key_series)
                yield (batch_key, data_pandas)

        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        data_batches = generate_data_batches(_batches)

        for k, g in groupby(data_batches, key=lambda x: x[0]):
            yield (k, g)

    def dump_stream(self, iterator, stream):
        """
        Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
        RecordBatches, and write batches to stream.
        """
        result = [(b, t) for x in iterator for y, t in x for b in y]
        super().dump_stream(result, stream)
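The lazy grouping in load_stream relies on itertools.groupby seeing all chunks for one grouping key arrive contiguously from the stream. A small standalone sketch of that pattern, with plain tuples standing in for the deserialized Arrow batches, shows the shape of the yielded (key, group-iterator) pairs.

# Standalone sketch of the grouping pattern used by load_stream above:
# consecutive (key, chunk) pairs are grouped lazily, one grouping key at a time.
from itertools import groupby

def generate_data_batches():
    # Stand-ins for deserialized batches; chunks for a given key arrive contiguously.
    yield (("a",), [1, 2])
    yield (("a",), [3])
    yield (("b",), [4, 5])

for key, group in groupby(generate_data_batches(), key=lambda x: x[0]):
    # Each group is itself a lazy iterator; chunks materialize only when consumed.
    chunks = [chunk for _, chunk in group]
    print(key, chunks)  # ('a',) [[1, 2], [3]]  then  ('b',) [[4, 5]]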
New protobuf message definitions used by the state API client:

@@ -0,0 +1,86 @@
| syntax = "proto3"; | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| package pyspark.sql.streaming; | ||
|
|
||
| message StateRequest { | ||
|
||
| int32 version = 1; | ||
| oneof method { | ||
| StatefulProcessorCall statefulProcessorCall = 2; | ||
| StateVariableRequest stateVariableRequest = 3; | ||
| ImplicitGroupingKeyRequest implicitGroupingKeyRequest = 4; | ||
| } | ||
| } | ||
|
|
||
| message StateResponse { | ||
| int32 statusCode = 1; | ||
| string errorMessage = 2; | ||
| } | ||
|
|
||
| message StatefulProcessorCall { | ||
| oneof method { | ||
| SetHandleState setHandleState = 1; | ||
| StateCallCommand getValueState = 2; | ||
| StateCallCommand getListState = 3; | ||
| StateCallCommand getMapState = 4; | ||
| } | ||
| } | ||
|
|
||
| message StateVariableRequest { | ||
| oneof method { | ||
| ValueStateCall valueStateCall = 1; | ||
| } | ||
| } | ||
|
|
||
| message ImplicitGroupingKeyRequest { | ||
| oneof method { | ||
| SetImplicitKey setImplicitKey = 1; | ||
| RemoveImplicitKey removeImplicitKey = 2; | ||
| } | ||
| } | ||
|
|
||
| message StateCallCommand { | ||
| string stateName = 1; | ||
| string schema = 2; | ||
| } | ||
|
|
||
| message ValueStateCall { | ||
| string stateName = 1; | ||
| oneof method { | ||
| Exists exists = 2; | ||
| Get get = 3; | ||
| ValueStateUpdate valueStateUpdate = 4; | ||
| Clear clear = 5; | ||
| } | ||
| } | ||
|
|
||
| message SetImplicitKey { | ||
| string key = 1; | ||
| } | ||
|
|
||
| message RemoveImplicitKey { | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| message Exists { | ||
| } | ||
|
|
||
| message Get { | ||
| } | ||
|
|
||
| message ValueStateUpdate { | ||
| string schema = 1; | ||
| bytes value = 2; | ||
| } | ||
|
|
||
| message Clear { | ||
| } | ||
|
|
||
| enum HandleState { | ||
| CREATED = 0; | ||
| INITIALIZED = 1; | ||
| DATA_PROCESSED = 2; | ||
| CLOSED = 3; | ||
| } | ||
|
|
||
| message SetHandleState { | ||
| HandleState state = 1; | ||
| } | ||
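As a rough illustration of how these messages compose on the Python side, the sketch below builds a StateRequest that sets the implicit grouping key and serializes it. The generated module name StateMessage_pb2 is an assumption about how the .proto file would be compiled (protoc --python_out), not something defined in this diff.

# Hypothetical use of the classes protoc would generate from this file
# (the module name StateMessage_pb2 is assumed, not part of the diff).
import StateMessage_pb2 as pb

request = pb.StateRequest()
request.version = 1
# Setting a nested message field selects the corresponding oneof branch.
request.implicitGroupingKeyRequest.setImplicitKey.key = "user_1"

payload = request.SerializeToString()  # bytes sent to the JVM side of the state API

response = pb.StateResponse()
response.ParseFromString(b"\x08\x00")  # e.g. a reply encoding statusCode = 0
assert response.statusCode == 0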