[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
Changes from 7 commits
File: python/pyspark/sql/pandas/group_ops.py
```diff
@@ -15,14 +15,17 @@
 # limitations under the License.
 #
 import sys
-from typing import List, Union, TYPE_CHECKING, cast
+from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
 import warnings

 from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions.builtin import udf
 from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.streaming.state_api_client import StateApiClient, StatefulProcessorHandleState
+from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.types import StructType, _parse_datatype_string

 if TYPE_CHECKING:
```
```diff
@@ -33,6 +36,7 @@
         PandasCogroupedMapFunction,
         ArrowGroupedMapFunction,
         ArrowCogroupedMapFunction,
+        DataFrameLike as PandasDataFrameLike,
     )
     from pyspark.sql.group import GroupedData

```
```diff
@@ -358,6 +362,55 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)

+    def transformWithStateInPandas(self,
+                                   stateful_processor: StatefulProcessor,
+                                   outputStructType: Union[StructType, str],
+                                   outputMode: str,
+                                   timeMode: str) -> DataFrame:
+        from pyspark.sql import GroupedData
+        from pyspark.sql.functions import pandas_udf
+        assert isinstance(self, GroupedData)
+
+        def transformWithStateUDF(state_api_client: StateApiClient, key: Any,
+                                  inputRows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(state_api_client)
+
+            print(f"checking handle state: {state_api_client.handle_state}")
+            if state_api_client.handle_state == StatefulProcessorHandleState.CREATED:
+                print("initializing stateful processor")
+                stateful_processor.init(handle)
+                print("setting handle state to initialized")
+                state_api_client.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+
+            print(f"handling input rows for key: {key[0]}")
+            state_api_client.setImplicitKey(str(key[0]))
+            result = stateful_processor.handleInputRows(key, inputRows)
+            state_api_client.removeImplicitKey()
+
+            return result
+
+        if isinstance(outputStructType, str):
+            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+
+        udf = pandas_udf(
+            transformWithStateUDF,  # type: ignore[call-overload]
+            returnType=outputStructType,
+            functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE,
+        )
+        df = self._df
+        udf_column = udf(*[df[col] for col in df.columns])
+
+        jdf = self._jgd.transformWithStateInPandas(
+            udf_column._jc.expr(),
+            self.session._jsparkSession.parseDataType(outputStructType.json()),
+            outputMode,
+            timeMode,
+        )
+        return DataFrame(jdf, self.session)
+
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
```
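To make the control flow above concrete, here is a hedged usage sketch of the new API. `RunningCountProcessor`, its schemas, the column names, and the input DataFrame `df` are invented for illustration, and the exact `ValueState` method signatures are assumptions modeled on the Exists/Get/Update commands this PR defines; only `transformWithStateInPandas`, `StatefulProcessor`, and `StatefulProcessorHandle` come from the diff itself.

```python
# Hypothetical usage sketch; names and signatures marked below are assumptions.
from typing import Iterator

import pandas as pd

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)
from pyspark.sql.types import IntegerType, StructField, StructType


class RunningCountProcessor(StatefulProcessor):
    """Keeps a per-key running count in a ValueState variable."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        # Assumed handle API: register a ValueState variable with its schema.
        schema = StructType([StructField("count", IntegerType(), True)])
        self._count = handle.getValueState("countState", schema)

    def handleInputRows(self, key, rows: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        # Assumed ValueState surface mirroring the proto's Exists/Get/Update.
        prev = self._count.get()[0] if self._count.exists() else 0
        new_count = prev + sum(len(pdf) for pdf in rows)
        self._count.update((new_count,))
        yield pd.DataFrame({"id": [key[0]], "count": [new_count]})


# `df` is a hypothetical streaming DataFrame with an `id` column.
result = df.groupBy("id").transformWithStateInPandas(
    stateful_processor=RunningCountProcessor(),
    outputStructType="id string, count int",
    outputMode="Update",
    timeMode="None",
)
```

Note how this lines up with `transformWithStateUDF`: `init` runs once per handle (guarded by the CREATED check), while `handleInputRows` runs per grouping key, bracketed by `setImplicitKey` and `removeImplicitKey`.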
File: python/pyspark/sql/pandas/serializers.py
```diff
@@ -19,9 +19,14 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """

+from enum import Enum
+from itertools import groupby
+import os
+import socket
+from typing import Any
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
 from pyspark.loose_version import LooseVersion
-from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
+from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer, write_with_length
 from pyspark.sql.pandas.types import (
     from_arrow_type,
     to_arrow_type,
```
```diff
@@ -1101,6 +1106,7 @@ def init_stream_yield_batches(batches):
         This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a
         START_ARROW_STREAM before the Arrow stream is sent.

         START_ARROW_STREAM should be sent after creating the first record batch so in case of
         an error, it can be sent back to the JVM before the Arrow stream starts.
         """
```
```diff
@@ -1116,3 +1122,86 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())

         return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket
+        # self._client_socket = socket.socket()
+        # self._client_socket.connect(("localhost", state_server_port))
+        # sockfile = self._client_socket.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+        # self.state_serializer = TransformWithStateInPandasStateSerializer(sockfile)
+
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    # Nothing special here, we need to create the handle and read
+    # data in groups.
+    def load_stream(self, stream):
+        import pyarrow as pa
+        from itertools import tee
+
+        def generate_data_batches(batches):
+            for batch in batches:
+                data_pandas = [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+                key_series = [data_pandas[o] for o in self.key_offsets]
+                batch_key = tuple(s[0] for s in key_series)
+                yield (batch_key, data_pandas)
+
+        print("Generating data batches...")
+        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        data_batches = generate_data_batches(_batches)
+
+        print("Returning data batches...")
+        for k, g in groupby(data_batches, key=lambda x: x[0]):
+            yield (k, g)
+
+    def dump_stream(self, iterator, stream):
+        result = [(b, t) for x in iterator for y, t in x for b in y]
+        super().dump_stream(result, stream)
```
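The grouping step above depends on an ordering contract: the JVM must send all Arrow batches for one grouping key contiguously, so a single streaming pass with `itertools.groupby` can reassemble them without buffering the whole input. A standalone sketch of that pattern, with made-up stand-in data for the decoded `(key, data)` pairs:

```python
# Standalone illustration of the groupby-based regrouping in load_stream.
# `decoded` stands in for (batch_key, data) pairs produced per Arrow batch;
# like the serializer, it assumes equal keys arrive contiguously.
from itertools import groupby

decoded = [(("a",), [1, 2]), (("a",), [3]), (("b",), [4])]

for key, group in groupby(decoded, key=lambda kv: kv[0]):
    # `group` is a lazy iterator over this key's batches, in arrival order.
    print(key, [data for _, data in group])
# (('a',), [[1, 2], [3]])
# (('b',), [[4]])
```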
The same hunk continues with the implicit-key tracker and the socket-backed state serializer:

```diff
+class ImplicitGroupingKeyTracker:
+    def __init__(self) -> None:
+        self._key = None
+
+    def setKey(self, key: Any) -> None:
+        self._key = key
+
+    def getKey(self) -> Any:
+        return self._key
+
+
+class TransformWithStateInPandasStateSerializer:
+    def __init__(self, sockfile) -> None:
+        self.sockfile = sockfile
+        self.grouping_key_tracker = ImplicitGroupingKeyTracker()
+
+    def load_stream(self, stream):
+        pass
+
+    def dump_stream(self, iterator, stream):
+        pass
+
+    def send(self, proto_message):
+        write_with_length(proto_message, self.sockfile)
+        self.sockfile.flush()
+
+    def receive(self):
+        return read_int(self.sockfile)
+
+    def readStr(self):
+        return self.sockfile.readline()
```
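`send` and `receive` rely on the framing helpers imported from `pyspark.serializers`: `write_with_length` writes a 4-byte big-endian length prefix followed by the raw payload bytes, and `read_int` reads a 4-byte big-endian integer (here, a status code) off the same socket file. A minimal sketch of that wire format, with `io.BytesIO` standing in for the socket file object:

```python
# Minimal sketch of the length-prefixed framing used by send()/receive().
import io
import struct


def write_with_length(payload: bytes, stream) -> None:
    stream.write(struct.pack(">i", len(payload)))  # 4-byte big-endian length
    stream.write(payload)


buf = io.BytesIO()
write_with_length(b"\x08\x01", buf)  # e.g. a serialized proto message

buf.seek(0)
(length,) = struct.unpack(">i", buf.read(4))
assert buf.read(length) == b"\x08\x01"
```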
New file (98 added lines): protobuf definitions for the Python worker ↔ JVM state-server protocol, package `pyspark.sql.streaming`. Each request is a versioned envelope carrying exactly one method via a `oneof`.
| syntax = "proto3"; | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| package pyspark.sql.streaming; | ||
|
|
||
| message StateRequest { | ||
|
||
| int32 version = 1; | ||
| oneof method { | ||
| StatefulProcessorCall statefulProcessorCall = 2; | ||
| StateVariableRequest stateVariableRequest = 3; | ||
| ImplicitGroupingKeyRequest implicitGroupingKeyRequest = 4; | ||
| } | ||
| } | ||
|
|
||
| message StateResponse { | ||
| int32 statusCode = 1; | ||
| string errorMessage = 2; | ||
| } | ||
|
|
||
| message StatefulProcessorCall { | ||
| oneof method { | ||
| SetHandleState setHandleState = 1; | ||
| StateCallCommand getValueState = 2; | ||
| StateCallCommand getListState = 3; | ||
| StateCallCommand getMapState = 4; | ||
| } | ||
| } | ||
|
|
||
| message StateVariableRequest { | ||
| oneof method { | ||
| ValueStateCall valueStateCall = 1; | ||
| ListStateCall listStateCall = 2; | ||
| } | ||
| } | ||
|
|
||
| message ImplicitGroupingKeyRequest { | ||
| oneof method { | ||
| SetImplicitKey setImplicitKey = 1; | ||
| RemoveImplicitKey removeImplicitKey = 2; | ||
| } | ||
| } | ||
|
|
||
| message StateCallCommand { | ||
| string stateName = 1; | ||
| string schema = 2; | ||
| } | ||
|
|
||
| message ValueStateCall { | ||
| oneof method { | ||
| Exists exists = 1; | ||
| Get get = 2; | ||
| Update update = 3; | ||
| Clear clear = 4; | ||
| } | ||
| } | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| message ListStateCall { | ||
| oneof method { | ||
| Exists exists = 1; | ||
| Get get = 2; | ||
| Clear clear = 3; | ||
| } | ||
| } | ||
|
|
||
| message SetImplicitKey { | ||
| string key = 1; | ||
| } | ||
|
|
||
| message RemoveImplicitKey { | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| message Exists { | ||
| string stateName = 1; | ||
| } | ||
|
|
||
| message Get { | ||
| string stateName = 1; | ||
| } | ||
|
|
||
| message Update { | ||
|
||
| string stateName = 1; | ||
| string schema = 2; | ||
|
||
| bytes value = 3; | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| message Clear { | ||
| string stateName = 1; | ||
| } | ||
|
|
||
| enum HandleState { | ||
| CREATED = 0; | ||
| INITIALIZED = 1; | ||
| DATA_PROCESSED = 2; | ||
| CLOSED = 3; | ||
| } | ||
|
|
||
| message SetHandleState { | ||
| HandleState state = 1; | ||
| } | ||
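For orientation, a hedged sketch of how these messages could be exercised from Python once the file is compiled with `protoc --python_out=...`; the `StateMessage_pb2` module name and the key value are illustrative, not part of this diff:

```python
# Hypothetical round trip over the request protocol defined above.
import StateMessage_pb2 as pb  # assumed name of the protoc-generated module

# Build the request the Python worker would send before touching state:
# a versioned envelope with exactly one `method` branch populated.
request = pb.StateRequest(
    version=1,
    implicitGroupingKeyRequest=pb.ImplicitGroupingKeyRequest(
        setImplicitKey=pb.SetImplicitKey(key="user-42")
    ),
)
payload = request.SerializeToString()  # bytes, framed with write_with_length

# The receiving side parses the envelope and dispatches on the oneof branch.
parsed = pb.StateRequest()
parsed.ParseFromString(payload)
assert parsed.WhichOneof("method") == "implicitGroupingKeyRequest"
```

The `oneof method` envelope keeps the socket protocol to a single message type per direction: the server switches on the populated branch rather than on ad-hoc opcodes, and `StateResponse` carries only a status code and an optional error message back.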