Commits (55)
b1175e4
[WIP] Poc.
sahnib Jan 22, 2024
0a98ed8
Introduce Protobuf.
sahnib Feb 6, 2024
8e2b193
Fixing things.
sahnib Feb 6, 2024
16e4c17
support timeMode for python state v2 API
bogao007 Jun 20, 2024
92ef716
Add protobuf for serde
sahnib Feb 13, 2024
c3eaf38
protobuf change
bogao007 Jun 20, 2024
609d94e
Initial commit
bogao007 Jun 27, 2024
a27f9d9
better error handling, support value state with different types
bogao007 Jun 28, 2024
684939b
addressed comments
bogao007 Jul 3, 2024
7f65fbd
fix
bogao007 Jul 3, 2024
c25d7da
Added support for unix domain socket
bogao007 Jul 11, 2024
9c8c616
removed unrelated log lines, addressed part of the comments
bogao007 Jul 17, 2024
c641192
fix
bogao007 Jul 17, 2024
8d3da4e
Addressed comments
bogao007 Jul 19, 2024
cc9bf95
removed unnecessary print
bogao007 Jul 19, 2024
f7df2dc
rename
bogao007 Jul 19, 2024
27cd169
fix
bogao007 Jul 19, 2024
3b5b3e5
removed duplicate proto file
bogao007 Jul 20, 2024
5d910d8
revert unrelated changes
bogao007 Jul 20, 2024
df859ab
fix
bogao007 Jul 20, 2024
654f2f6
Added unit tests for transformWithStateInPandas
bogao007 Jul 24, 2024
38832a6
Merge branch 'master' into state-v2-initial
bogao007 Jul 24, 2024
0585ac0
fix and rename
bogao007 Jul 24, 2024
0ee5029
update test
bogao007 Jul 24, 2024
6232c81
Added lisences
bogao007 Jul 25, 2024
41f8234
fixed format issues
bogao007 Jul 25, 2024
d57633f
fix
bogao007 Jul 25, 2024
df9ea9e
fix format
bogao007 Jul 25, 2024
68f7a7e
doc
bogao007 Jul 25, 2024
ca5216b
addressed comments
bogao007 Jul 26, 2024
c9e3a7c
structured log
bogao007 Jul 26, 2024
2320805
suppress auto generated proto file
bogao007 Jul 29, 2024
6e5de2e
fix linter
bogao007 Jul 29, 2024
200ec5e
fixed dependency issue
bogao007 Jul 29, 2024
dd3e46b
make protobuf as local dependency
bogao007 Jul 30, 2024
e8360d4
fix dependency issue
bogao007 Jul 30, 2024
82983af
fix
bogao007 Jul 30, 2024
49dbc16
fix lint
bogao007 Jul 30, 2024
d4e04ea
fix
bogao007 Jul 30, 2024
e108f60
updated fix
bogao007 Jul 30, 2024
bae26c2
reformat
bogao007 Jul 30, 2024
d96fa9e
addressed comments
bogao007 Jul 31, 2024
92531db
fix linter
bogao007 Jul 31, 2024
d507793
linter
bogao007 Jul 31, 2024
5dcb4c8
addressed comments
bogao007 Aug 2, 2024
37be02a
address comment
bogao007 Aug 2, 2024
f63687f
addressed comments
bogao007 Aug 9, 2024
263c087
Merge branch 'master' into state-v2-initial
bogao007 Aug 10, 2024
c7b0a4f
address comments
bogao007 Aug 12, 2024
c80b292
address comments
bogao007 Aug 12, 2024
81276f3
address comments
bogao007 Aug 14, 2024
5886b5c
fix lint
bogao007 Aug 14, 2024
23e54b4
fix lint
bogao007 Aug 14, 2024
2ba4fd0
address comments
bogao007 Aug 14, 2024
2a9c20b
fix test
bogao007 Aug 14, 2024
removed unrelated log lines, addressed part of the comments
bogao007 committed Jul 17, 2024
commit 9c8c6169a32961fa3237ceaed8c2a82c6f7dea7d
31 changes: 5 additions & 26 deletions python/pyspark/sql/pandas/group_ops.py
@@ -378,24 +378,7 @@ def transformWithStateInPandas(self,
invocations.

The `stateful_processor` should be a Python class that implements the interface defined in
pyspark.sql.streaming.stateful_processor. The stateful processor consists of 3 functions:
`init`, `handleInputRows`, and `close`.

The `init` function will be invoked as the first method that allows for users to initialize
all their state variables and perform other init actions before handling data.

The `handleInputRows` function will allow users to interact with input data rows. It should
take parameters (key, Iterator[`pandas.DataFrame`]) and return another
Iterator[`pandas.DataFrame`]. For each group, all columns are passed together as
`pandas.DataFrame` to the `handleInputRows` function, and the returned `pandas.DataFrame`
across all invocations are combined as a :class:`DataFrame`. Note that the `handleInputRows`
function should not guess the number of elements in the iterator. To process all data, the
`handleInputRows` function needs to iterate over all elements and process them. On the
other hand, the `handleInputRows` function is not strictly required to iterate through all
elements in the iterator if it intends to read only a part of the data.

The `close` function will be called as the last method that allows for users to perform any
cleanup or teardown operations.
pyspark.sql.streaming.stateful_processor.StatefulProcessor.

The `outputStructType` should be a :class:`StructType` describing the schema of all
elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
@@ -410,13 +393,13 @@ def transformWithStateInPandas(self,

Parameters
----------
stateful_processor : StatefulProcessor
Instance of statefulProcessor whose functions will be invoked by the operator.
stateful_processor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor`
Instance of StatefulProcessor whose functions will be invoked by the operator.
outputStructType : :class:`pyspark.sql.types.DataType` or str
the type of the output records. The value can be either a
The type of the output records. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
outputMode : str
the output mode of the stateful processor.
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.

@@ -463,14 +446,10 @@ def transformWithStateUDF(state_api_client: StateApiClient, key: Any,
inputRows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(state_api_client)

print(f"checking handle state: {state_api_client.handle_state}")
if (state_api_client.handle_state == StatefulProcessorHandleState.CREATED):
print("initializing stateful processor")
stateful_processor.init(handle)
print("setting handle state to initialized")
state_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

print(f"handling input rows for key: {key[0]}")
state_api_client.set_implicit_key(str(key[0]))
result = stateful_processor.handleInputRows(key, inputRows)
state_api_client.remove_implicit_key()
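Restated outside the closure, the wrapped UDF's control flow is a small state machine: `init()` runs once per worker while the handle is CREATED, and every batch is bracketed by implicit-key management. A sketch under those assumptions (the try/finally is an illustrative hardening, not in the diff):

```python
from typing import Any, Iterator

def drive_processor(
    processor: "StatefulProcessor",
    client: "StateApiClient",
    key: Any,
    rows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
    # One-time initialization: CREATED -> INITIALIZED.
    handle = StatefulProcessorHandle(client)
    if client.handle_state == StatefulProcessorHandleState.CREATED:
        processor.init(handle)
        client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
    # Scope all state operations to the current grouping key.
    client.set_implicit_key(str(key[0]))
    try:
        yield from processor.handleInputRows(key, rows)
    finally:
        client.remove_implicit_key()
```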
13 changes: 0 additions & 13 deletions python/pyspark/sql/pandas/serializers.py
@@ -1171,7 +1171,6 @@ def load_stream(self, stream):
chunks for each group, so that the caller can lazily materialize the data chunk.
"""
import pyarrow as pa
from itertools import tee

def generate_data_batches(batches):
for batch in batches:
@@ -1180,11 +1179,9 @@ def generate_data_batches(batches):
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas)

print("Generating data batches...")
_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

print("Returning data batches...")
for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)

@@ -1196,13 +1193,3 @@ def dump_stream(self, iterator, stream):
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)

class ImplicitGroupingKeyTracker:
def __init__(self) -> None:
self._key = None

def setKey(self, key: Any) -> None:
self._key = key

def getKey(self) -> Any:
return self._key
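The load_stream contract is easy to see with plain itertools.groupby: (key, pandas chunk) tuples arriving key-sorted are regrouped into (key, lazy iterator of chunks) pairs. A self-contained sketch with toy data (not Spark's actual Arrow stream):

```python
from itertools import groupby

import pandas as pd

# Toy stand-in for the (batch_key, data_pandas) tuples that
# generate_data_batches yields; real input arrives from Arrow record
# batches, key-sorted within the stream.
data_batches = [
    (("a",), pd.DataFrame({"value": [1]})),
    (("a",), pd.DataFrame({"value": [2]})),
    (("b",), pd.DataFrame({"value": [3]})),
]

for key, chunks in groupby(data_batches, key=lambda x: x[0]):
    # `chunks` is a lazy iterator of (key, DataFrame) tuples for this key,
    # matching what the stateful processor consumes in handleInputRows.
    total = sum(df["value"].sum() for _, df in chunks)
    print(key, total)  # (('a',), 3) then (('b',), 3)
```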
13 changes: 0 additions & 13 deletions python/pyspark/sql/streaming/state_api_client.py
@@ -41,12 +41,10 @@ def __init__(
self._client_socket.connect(server_address)
self.sockfile = self._client_socket.makefile("rwb",
int(os.environ.get("SPARK_BUFFER_SIZE",65536)))
Contributor review comment: How do we use the buffer size here?
print(f"client is ready - connection established")
self.handle_state = StatefulProcessorHandleState.CREATED
self.utf8_deserializer = UTF8Deserializer()

def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
Contributor review comment: Could we add 1-2 line comments for these functions?
print(f"setting handle state to: {state}")
proto_state = self._get_proto_state(state)
set_handle_state = stateMessage.SetHandleState(state=proto_state)
handle_call = stateMessage.StatefulProcessorCall(setHandleState=set_handle_state)
@@ -60,40 +58,33 @@ def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
self.handle_state = state
else:
raise Exception(f"Error setting handle state: {response_message.errorMessage}")
print(f"setHandleState status= {status}")

def set_implicit_key(self, key: str) -> None:
print(f"setting implicit key: {key}")
set_implicit_key = stateMessage.SetImplicitKey(key=key)
request = stateMessage.ImplicitGroupingKeyRequest(setImplicitKey=set_implicit_key)
message = stateMessage.StateRequest(implicitGroupingKeyRequest=request)

self._send_proto_message(message)
response_message = self._receive_proto_message()
status = response_message.statusCode
print(f"setImplicitKey status= {status}")
if (status != 0):
raise Exception(f"Error setting implicit key: {response_message.errorMessage}")

def remove_implicit_key(self) -> None:
print(f"removing implicit key")
remove_implicit_key = stateMessage.RemoveImplicitKey()
request = stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key)
message = stateMessage.StateRequest(implicitGroupingKeyRequest=request)

self._send_proto_message(message)
response_message = self._receive_proto_message()
status = response_message.statusCode
print(f"removeImplicitKey status= {status}")
if (status != 0):
raise Exception(f"Error removing implicit key: {response_message.errorMessage}")

def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None:
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))

print(f"initializing value state: {state_name}")

state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = state_name
state_call_command.schema = schema.json()
@@ -120,22 +111,18 @@ def _get_proto_state(self,

def _send_proto_message(self, message: stateMessage.StateRequest) -> None:
serialized_msg = message.SerializeToString()
print(f"sending message -- len = {len(serialized_msg)} {str(serialized_msg)}")
write_int(0, self.sockfile)
write_int(len(serialized_msg), self.sockfile)
self.sockfile.write(serialized_msg)
self.sockfile.flush()

def _receive_proto_message(self) -> stateMessage.StateResponse:
serialized_msg = self._receive_str()
print(f"received response message -- len = {len(serialized_msg)} {str(serialized_msg)}")
# proto3 will not serialize the message if the value is default, in this case 0
if (len(serialized_msg) == 0):
return stateMessage.StateResponse(statusCode=0)
message = stateMessage.StateResponse()
message.ParseFromString(serialized_msg.encode('utf-8'))
print(f"received response message -- status = {str(message.statusCode)},"
f" message = {message.errorMessage}")
return message

def _receive_str(self) -> str:
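The client/server exchange above is a simple length-prefixed protobuf framing. A standalone sketch of the write side, assuming `struct.pack(">i", ...)` matches Spark's big-endian `write_int` helper and that the leading 0 is the marker used in `_send_proto_message`:

```python
import struct

def send_frame(sock_file, message) -> None:
    # Mirrors _send_proto_message: a 0 marker, then a 4-byte big-endian
    # length prefix, then the serialized protobuf bytes.
    payload = message.SerializeToString()
    sock_file.write(struct.pack(">i", 0))             # marker, as in the diff
    sock_file.write(struct.pack(">i", len(payload)))  # length prefix
    sock_file.write(payload)
    sock_file.flush()

# On the read side, note the proto3 quirk handled in _receive_proto_message:
# a response whose fields are all default values serializes to zero bytes,
# so an empty payload must be decoded as StateResponse(statusCode=0).
```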
87 changes: 87 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -33,6 +33,10 @@


class ValueState:
"""
Class used for arbitrary stateful operations with the v2 API to capture single value state.
"""
Contributor review comment: We should not call transformWithState the v2 API, as only a few people would know what v2 is. Please call it by its name.

def __init__(self,
value_state_client: ValueStateClient,
state_name: str,
@@ -42,9 +46,19 @@ def __init__(self,
self.schema = schema

def exists(self) -> bool:
"""
Whether state exists or not.

.. versionadded:: 4.0.0
"""
Member review comment: Adding it at the class-level docstring should be enough.
return self._value_state_client.exists(self._state_name)

def get(self) -> Any:
Contributor review comment: Again, we expect Row as the state value, not a pandas DataFrame. Please let me know if you are proposing a pandas DataFrame as a better fit for more state types.
"""
Get the state value if it exists.

.. versionadded:: 4.0.0
"""
value_str = self._value_state_client.get(self._state_name)
columns = [field.name for field in self.schema.fields]
dtypes = {}
@@ -67,35 +81,108 @@ def get(self) -> Any:
return df

def update(self, new_value: Any) -> None:
"""
Update the value of the state.

.. versionadded:: 4.0.0
"""
self._value_state_client.update(self._state_name, self.schema, new_value)

def clear(self) -> None:
"""
Remove this state.

.. versionadded:: 4.0.0
"""
self._value_state_client.clear(self._state_name)

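Taken together, the four methods support a read-modify-write loop. A sketch against a hypothetical `count_state` variable (created via handle.getValueState, shown below); note that at this stage of the PR the value travels as a UTF-8 string, and get() reconstructs a one-row pandas DataFrame from the declared schema:

```python
# Hypothetical state variable with illustrative column name "count_value",
# assumed to have been created in init() via
# handle.getValueState("count", "count_value int").
if count_state.exists():
    count = int(count_state.get()["count_value"].iloc[0])
else:
    count = 0
count_state.update(str(count + 1))  # value is sent as a UTF-8 string here
# count_state.clear() would drop the value for the current grouping key.
```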

class StatefulProcessorHandle:
"""
Represents the operation handle provided to the stateful processor used in the arbitrary state
API v2.
"""
Contributor review comment: nit: transformWithState

def __init__(
self,
state_api_client: StateApiClient) -> None:
self.state_api_client = state_api_client

def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState:
"""
Function to create a new single value state variable of the given type, or return an
existing one. The user must call this function only within the `init()` method of the
StatefulProcessor.

.. versionadded:: 4.0.0

Parameters
----------
state_name : str
name of the state variable
schema : :class:`pyspark.sql.types.DataType` or str
The schema of the state variable. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
"""
self.state_api_client.get_value_state(state_name, schema)
return ValueState(ValueStateClient(self.state_api_client), state_name, schema)


class StatefulProcessor(ABC):
"""
Class that represents the arbitrary stateful logic that needs to be provided by the user to
perform stateful manipulations on keyed streams.
"""

@abstractmethod
def init(self, handle: StatefulProcessorHandle) -> None:
"""
Function that will be invoked as the first method, allowing users to initialize all
their state variables and perform other init actions before handling data.

.. versionadded:: 4.0.0

Parameters
----------
handle : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessorHandle`
Handle to the stateful processor that provides access to the state store and other
stateful processing related APIs.
"""
pass

@abstractmethod
def handleInputRows(
self,
key: Any,
rows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]:
"""
Function that will allow users to interact with input data rows along with the grouping key.
It should take parameters (key, Iterator[`pandas.DataFrame`]) and return another
Iterator[`pandas.DataFrame`]. For each group, all columns are passed together as
`pandas.DataFrame` to the function, and the returned `pandas.DataFrame` across all
invocations are combined as a :class:`DataFrame`. Note that the function should not
guess the number of elements in the iterator. To process all data, the `handleInputRows`
function needs to iterate over all elements and process them. On the other hand, the
`handleInputRows` function is not strictly required to iterate through all elements in the
iterator if it intends to read only a part of the data.

.. versionadded:: 4.0.0

Parameters
----------
key : Any
Grouping key.
rows : iterable of :class:`pandas.DataFrame`
Iterator of input rows associated with the grouping key.
"""
pass

@abstractmethod
def close(self) -> None:
"""
Function called as the last method, allowing users to perform any cleanup or teardown
operations.

.. versionadded:: 4.0.0
"""
pass
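A complete (hypothetical) processor implementing the three-method contract above; the schemas, column names, and counting logic are illustrative only, and the import path is assumed to follow this diff:

```python
from typing import Any, Iterator

import pandas as pd

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)

class RunningCountProcessor(StatefulProcessor):
    """Illustrative processor: emits a running count of rows per key."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        # Create (or re-attach to) the single-value state before any data arrives.
        self.count_state = handle.getValueState("count", "count_value int")

    def handleInputRows(
        self, key: Any, rows: Iterator[pd.DataFrame]
    ) -> Iterator[pd.DataFrame]:
        count = 0
        if self.count_state.exists():
            count = int(self.count_state.get()["count_value"].iloc[0])
        # Iterate every chunk: the iterator's length must not be guessed.
        for pdf in rows:
            count += len(pdf)
        # At this stage of the PR, the value is shipped as a UTF-8 string.
        self.count_state.update(str(count))
        yield pd.DataFrame({"key": [key[0]], "count": [count]})

    def close(self) -> None:
        # Nothing to tear down in this sketch.
        pass
```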
8 changes: 0 additions & 8 deletions python/pyspark/sql/streaming/value_state_client.py
@@ -29,7 +29,6 @@ def __init__(
self._state_api_client = state_api_client

def exists(self, state_name: str) -> bool:
print(f"checking value state exists: {state_name}")
exists_call = stateMessage.Exists(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(exists=exists_call)
state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
@@ -38,7 +37,6 @@ def exists(self, state_name: str) -> bool:
self._state_api_client._send_proto_message(message)
response_message = self._state_api_client._receive_proto_message()
status = response_message.statusCode
print(f"valueStateExists status= {status}")
if (status == 0):
return True
elif (status == -1):
Expand All @@ -48,7 +46,6 @@ def exists(self, state_name: str) -> bool:
raise Exception(f"Error checking value state exists: {response_message.errorMessage}")

def get(self, state_name: str) -> Any:
print(f"getting value state: {state_name}")
get_call = stateMessage.Get(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(get=get_call)
state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
@@ -57,7 +54,6 @@ def get(self, state_name: str) -> Any:
self._state_api_client._send_proto_message(message)
response_message = self._state_api_client._receive_proto_message()
status = response_message.statusCode
print(f"valueStateGet status= {status}")
if (status == 0):
return self._state_api_client._receive_str()
else:
Expand All @@ -66,7 +62,6 @@ def get(self, state_name: str) -> Any:
def update(self, state_name: str, schema: Union[StructType, str], value: str) -> None:
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
print(f"updating value state: {state_name}")
byteStr = value.encode('utf-8')
update_call = stateMessage.Update(stateName=state_name, schema=schema.json(), value=byteStr)
value_state_call = stateMessage.ValueStateCall(update=update_call)
@@ -76,12 +71,10 @@ def update(self, state_name: str, schema: Union[StructType, str], value: str) ->
self._state_api_client._send_proto_message(message)
response_message = self._state_api_client._receive_proto_message()
status = response_message.statusCode
print(f"valueStateUpdate status= {status}")
if (status != 0):
raise Exception(f"Error updating value state: {response_message.errorMessage}")

def clear(self, state_name: str) -> None:
print(f"clearing value state: {state_name}")
clear_call = stateMessage.Clear(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(clear=clear_call)
state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
@@ -90,6 +83,5 @@ def clear(self, state_name: str) -> None:
self._state_api_client._send_proto_message(message)
response_message = self._state_api_client._receive_proto_message()
status = response_message.statusCode
print(f"valueStateClear status= {status}")
if (status != 0):
raise Exception(f"Error clearing value state: {response_message.errorMessage}")
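exists/get/update/clear all repeat the same build-send-check choreography. One possible consolidation, shown only to make the shared pattern explicit (not part of this PR; exists() would still need its own 0/-1 status handling, and get() its extra payload read):

```python
def _call_value_state(client, value_state_call, error_prefix: str):
    # Hypothetical helper: wrap a ValueStateCall, send it, and fail loudly
    # on a non-zero status, mirroring the four methods above.
    request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
    message = stateMessage.StateRequest(stateVariableRequest=request)
    client._send_proto_message(message)
    response = client._receive_proto_message()
    if response.statusCode != 0:
        raise Exception(f"{error_prefix}: {response.errorMessage}")
    return response

# With it, clear() collapses to roughly:
#   _call_value_state(
#       self._state_api_client,
#       stateMessage.ValueStateCall(clear=stateMessage.Clear(stateName=state_name)),
#       "Error clearing value state",
#   )
```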
2 changes: 1 addition & 1 deletion python/pyspark/worker.py
@@ -1631,7 +1631,7 @@ def extract_key_value_indexes(grouped_arg_offsets):
# support combining multiple UDFs.
assert num_udfs == 1

# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# See TransformWithStateInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, f = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler