-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
Changes from 1 commit
b1175e4
0a98ed8
8e2b193
16e4c17
92ef716
c3eaf38
609d94e
a27f9d9
684939b
7f65fbd
c25d7da
9c8c616
c641192
8d3da4e
cc9bf95
f7df2dc
27cd169
3b5b3e5
5d910d8
df859ab
654f2f6
38832a6
0585ac0
0ee5029
6232c81
41f8234
d57633f
df9ea9e
68f7a7e
ca5216b
c9e3a7c
2320805
6e5de2e
200ec5e
dd3e46b
e8360d4
82983af
49dbc16
d4e04ea
e108f60
bae26c2
d96fa9e
92531db
d507793
5dcb4c8
37be02a
f63687f
263c087
c7b0a4f
c80b292
81276f3
5886b5c
23e54b4
2ba4fd0
2a9c20b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,12 +41,10 @@ def __init__( | |
| self._client_socket.connect(server_address) | ||
| self.sockfile = self._client_socket.makefile("rwb", | ||
| int(os.environ.get("SPARK_BUFFER_SIZE",65536))) | ||
|
||
| print(f"client is ready - connection established") | ||
| self.handle_state = StatefulProcessorHandleState.CREATED | ||
| self.utf8_deserializer = UTF8Deserializer() | ||
|
|
||
| def set_handle_state(self, state: StatefulProcessorHandleState) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could we add a 1-2 line comment for each of these functions? |
||
| print(f"setting handle state to: {state}") | ||
| proto_state = self._get_proto_state(state) | ||
| set_handle_state = stateMessage.SetHandleState(state=proto_state) | ||
| handle_call = stateMessage.StatefulProcessorCall(setHandleState=set_handle_state) | ||
|
|
@@ -60,40 +58,33 @@ def set_handle_state(self, state: StatefulProcessorHandleState) -> None: | |
| self.handle_state = state | ||
| else: | ||
| raise Exception(f"Error setting handle state: {response_message.errorMessage}") | ||
| print(f"setHandleState status= {status}") | ||
|
|
||
| def set_implicit_key(self, key: str) -> None: | ||
| print(f"setting implicit key: {key}") | ||
| set_implicit_key = stateMessage.SetImplicitKey(key=key) | ||
| request = stateMessage.ImplicitGroupingKeyRequest(setImplicitKey=set_implicit_key) | ||
| message = stateMessage.StateRequest(implicitGroupingKeyRequest=request) | ||
|
|
||
| self._send_proto_message(message) | ||
| response_message = self._receive_proto_message() | ||
| status = response_message.statusCode | ||
| print(f"setImplicitKey status= {status}") | ||
| if (status != 0): | ||
| raise Exception(f"Error setting implicit key: {response_message.errorMessage}") | ||
|
|
||
| def remove_implicit_key(self) -> None: | ||
| print(f"removing implicit key") | ||
| remove_implicit_key = stateMessage.RemoveImplicitKey() | ||
| request = stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key) | ||
| message = stateMessage.StateRequest(implicitGroupingKeyRequest=request) | ||
|
|
||
| self._send_proto_message(message) | ||
| response_message = self._receive_proto_message() | ||
| status = response_message.statusCode | ||
| print(f"removeImplicitKey status= {status}") | ||
| if (status != 0): | ||
| raise Exception(f"Error removing implicit key: {response_message.errorMessage}") | ||
|
|
||
| def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None: | ||
| if isinstance(schema, str): | ||
| schema = cast(StructType, _parse_datatype_string(schema)) | ||
|
|
||
| print(f"initializing value state: {state_name}") | ||
|
|
||
| state_call_command = stateMessage.StateCallCommand() | ||
| state_call_command.stateName = state_name | ||
| state_call_command.schema = schema.json() | ||
|
|
@@ -120,22 +111,18 @@ def _get_proto_state(self, | |
|
|
||
| def _send_proto_message(self, message: stateMessage.StateRequest) -> None: | ||
| serialized_msg = message.SerializeToString() | ||
| print(f"sending message -- len = {len(serialized_msg)} {str(serialized_msg)}") | ||
| write_int(0, self.sockfile) | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| write_int(len(serialized_msg), self.sockfile) | ||
| self.sockfile.write(serialized_msg) | ||
| self.sockfile.flush() | ||
|
|
||
| def _receive_proto_message(self) -> stateMessage.StateResponse: | ||
| serialized_msg = self._receive_str() | ||
| print(f"received response message -- len = {len(serialized_msg)} {str(serialized_msg)}") | ||
| # proto3 will not serialize the message if the value is default, in this case 0 | ||
| if (len(serialized_msg) == 0): | ||
| return stateMessage.StateResponse(statusCode=0) | ||
| message = stateMessage.StateResponse() | ||
| message.ParseFromString(serialized_msg.encode('utf-8')) | ||
| print(f"received response message -- status = {str(message.statusCode)}," | ||
| f" message = {message.errorMessage}") | ||
| return message | ||
|
|
||
| def _receive_str(self) -> str: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,10 @@ | |
|
|
||
|
|
||
bogao007 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| class ValueState: | ||
bogao007 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Class used for arbitrary stateful operations with the v2 API to capture single value state. | ||
|
||
| """ | ||
|
|
||
| def __init__(self, | ||
| value_state_client: ValueStateClient, | ||
| state_name: str, | ||
|
|
@@ -42,9 +46,19 @@ def __init__(self, | |
| self.schema = schema | ||
|
|
||
| def exists(self) -> bool: | ||
| """ | ||
| Whether state exists or not. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
||
| """ | ||
| return self._value_state_client.exists(self._state_name) | ||
|
|
||
| def get(self) -> Any: | ||
|
||
| """ | ||
| Get the state value if it exists. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
| """ | ||
| value_str = self._value_state_client.get(self._state_name) | ||
| columns = [field.name for field in self.schema.fields] | ||
| dtypes = {} | ||
|
|
@@ -67,35 +81,108 @@ def get(self) -> Any: | |
| return df | ||
|
|
||
| def update(self, new_value: Any) -> None: | ||
| """ | ||
| Update the value of the state. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
| """ | ||
| self._value_state_client.update(self._state_name, self.schema, new_value) | ||
|
|
||
| def clear(self) -> None: | ||
| """ | ||
| Remove this state. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
| """ | ||
| self._value_state_client.clear(self._state_name) | ||
|
|
||
|
|
||
| class StatefulProcessorHandle: | ||
| """ | ||
| Represents the operation handle provided to the stateful processor used in the arbitrary state | ||
|
||
| API v2. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| state_api_client: StateApiClient) -> None: | ||
| self.state_api_client = state_api_client | ||
|
|
||
| def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState: | ||
| """ | ||
| Function to create new or return existing single value state variable of given type. | ||
| The user must ensure to call this function only within the `init()` method of the | ||
| StatefulProcessor. | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| state_name : str | ||
| name of the state variable | ||
| schema : :class:`pyspark.sql.types.DataType` or str | ||
| The schema of the state variable. The value can be either a | ||
| :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. | ||
| """ | ||
| self.state_api_client.get_value_state(state_name, schema) | ||
| return ValueState(ValueStateClient(self.state_api_client), state_name, schema) | ||
|
|
||
|
|
||
| class StatefulProcessor(ABC): | ||
HyukjinKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Class that represents the arbitrary stateful logic that needs to be provided by the user to | ||
| perform stateful manipulations on keyed streams. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def init(self, handle: StatefulProcessorHandle) -> None: | ||
| """ | ||
| Function that will be invoked as the first method that allows for users to initialize all | ||
| their state variables and perform other init actions before handling data. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| handle : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessorHandle` | ||
| Handle to the stateful processor that provides access to the state store and other | ||
| stateful processing related APIs. | ||
| """ | ||
| pass | ||
bogao007 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @abstractmethod | ||
| def handleInputRows( | ||
| self, | ||
| key: Any, | ||
| rows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]: | ||
| """ | ||
| Function that will allow users to interact with input data rows along with the grouping key. | ||
| It should take parameters (key, Iterator[`pandas.DataFrame`]) and return another | ||
| Iterator[`pandas.DataFrame`]. For each group, all columns are passed together as | ||
| `pandas.DataFrame` to the function, and the returned `pandas.DataFrame` across all | ||
| invocations are combined as a :class:`DataFrame`. Note that the function should not make a | ||
| guess of the number of elements in the iterator. To process all data, the `handleInputRows` | ||
| function needs to iterate all elements and process them. On the other hand, the | ||
| `handleInputRows` function is not strictly required to iterate through all elements in the | ||
| iterator if it intends to read a part of data. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| key : Any | ||
| grouping key. | ||
| rows : iterable of :class:`pandas.DataFrame` | ||
| iterator of input rows associated with grouping key | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| def close(self) -> None: | ||
| """ | ||
| Function called as the last method that allows for users to perform any cleanup or teardown | ||
| operations. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
| """ | ||
| pass | ||
Uh oh!
There was an error while loading. Please reload this page.