Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions python/pyspark/sql/streaming/list_state_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Iterator, List, Union, cast, Tuple
from typing import Dict, Iterator, List, Tuple

from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
from pyspark.sql.types import StructType, TYPE_CHECKING
from pyspark.errors import PySparkRuntimeError
import uuid

Expand Down Expand Up @@ -105,11 +105,9 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
pandas_row = pandas_df.iloc[index]
return tuple(pandas_row)

def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None:
def append_value(self, state_name: str, schema: StructType, value: Tuple) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value)
append_value_call = stateMessage.AppendValue(value=bytes)
list_state_call = stateMessage.ListStateCall(
Expand All @@ -125,13 +123,9 @@ def append_value(self, state_name: str, schema: Union[StructType, str], value: T
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")

def append_list(
self, state_name: str, schema: Union[StructType, str], values: List[Tuple]
) -> None:
def append_list(self, state_name: str, schema: StructType, values: List[Tuple]) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
append_list_call = stateMessage.AppendList()
list_state_call = stateMessage.ListStateCall(
stateName=state_name, appendList=append_list_call
Expand All @@ -148,11 +142,9 @@ def append_list(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")

def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None:
def put(self, state_name: str, schema: StructType, values: List[Tuple]) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
put_call = stateMessage.ListStatePut()
list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
Expand Down
18 changes: 6 additions & 12 deletions python/pyspark/sql/streaming/map_state_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Iterator, Union, cast, Tuple, Optional
from typing import Dict, Iterator, Tuple, Optional

from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
from pyspark.sql.types import StructType, TYPE_CHECKING
from pyspark.errors import PySparkRuntimeError
import uuid

Expand All @@ -31,18 +31,12 @@ class MapStateClient:
def __init__(
self,
stateful_processor_api_client: StatefulProcessorApiClient,
user_key_schema: Union[StructType, str],
value_schema: Union[StructType, str],
user_key_schema: StructType,
value_schema: StructType,
) -> None:
self._stateful_processor_api_client = stateful_processor_api_client
if isinstance(user_key_schema, str):
self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
else:
self.user_key_schema = user_key_schema
if isinstance(value_schema, str):
self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
else:
self.value_schema = value_schema
self.user_key_schema = user_key_schema
self.value_schema = value_schema
# Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
# and the index of the last row that was read.
self.user_key_value_pair_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
Expand Down
148 changes: 75 additions & 73 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Large diffs are not rendered by default.

65 changes: 55 additions & 10 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,41 @@ class StateResponseWithLongTypeVal(google.protobuf.message.Message):

global___StateResponseWithLongTypeVal = StateResponseWithLongTypeVal

class StateResponseWithStringTypeVal(google.protobuf.message.Message):
    """Type stub for the ``StateResponseWithStringTypeVal`` protobuf message.

    The message declares four fields: an integer ``statusCode`` and three
    strings, ``errorMessage``, ``value`` and ``additionalValue``.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Field-number constants, one per declared message field.
    STATUSCODE_FIELD_NUMBER: builtins.int
    ERRORMESSAGE_FIELD_NUMBER: builtins.int
    VALUE_FIELD_NUMBER: builtins.int
    ADDITIONALVALUE_FIELD_NUMBER: builtins.int
    # Message fields.
    # NOTE(review): the meaning of `value` / `additionalValue` is not visible
    # in this stub -- confirm against the StateMessage .proto definition.
    statusCode: builtins.int
    errorMessage: builtins.str
    value: builtins.str
    additionalValue: builtins.str
    # All constructor arguments are keyword-only and optional.
    def __init__(
        self,
        *,
        statusCode: builtins.int = ...,
        errorMessage: builtins.str = ...,
        value: builtins.str = ...,
        additionalValue: builtins.str = ...,
    ) -> None: ...
    # Each field name is accepted as either ``str`` or ``bytes``.
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "additionalValue",
            b"additionalValue",
            "errorMessage",
            b"errorMessage",
            "statusCode",
            b"statusCode",
            "value",
            b"value",
        ],
    ) -> None: ...

global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal

class StatefulProcessorCall(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

Expand Down Expand Up @@ -496,32 +531,42 @@ class StateCallCommand(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

STATENAME_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
MAPSTATEVALUESCHEMA_FIELD_NUMBER: builtins.int
JSONSCHEMA_FIELD_NUMBER: builtins.int
STRINGSCHEMA_FIELD_NUMBER: builtins.int
MAPSTATEVALUEJSONSCHEMA_FIELD_NUMBER: builtins.int
MAPSTATEVALUESTRINGSCHEMA_FIELD_NUMBER: builtins.int
TTL_FIELD_NUMBER: builtins.int
stateName: builtins.str
schema: builtins.str
mapStateValueSchema: builtins.str
jsonSchema: builtins.str
stringSchema: builtins.str
mapStateValueJsonSchema: builtins.str
mapStateValueStringSchema: builtins.str
@property
def ttl(self) -> global___TTLConfig: ...
def __init__(
self,
*,
stateName: builtins.str = ...,
schema: builtins.str = ...,
mapStateValueSchema: builtins.str = ...,
jsonSchema: builtins.str = ...,
stringSchema: builtins.str = ...,
mapStateValueJsonSchema: builtins.str = ...,
mapStateValueStringSchema: builtins.str = ...,
ttl: global___TTLConfig | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["ttl", b"ttl"]) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"mapStateValueSchema",
b"mapStateValueSchema",
"schema",
b"schema",
"jsonSchema",
b"jsonSchema",
"mapStateValueJsonSchema",
b"mapStateValueJsonSchema",
"mapStateValueStringSchema",
b"mapStateValueStringSchema",
"stateName",
b"stateName",
"stringSchema",
b"stringSchema",
"ttl",
b"ttl",
],
Expand Down
29 changes: 21 additions & 8 deletions python/pyspark/sql/streaming/stateful_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ValueState:
"""

def __init__(
self, value_state_client: ValueStateClient, state_name: str, schema: Union[StructType, str]
self, value_state_client: ValueStateClient, state_name: str, schema: StructType
) -> None:
self._value_state_client = value_state_client
self._state_name = state_name
Expand Down Expand Up @@ -128,7 +128,7 @@ class ListState:
"""

def __init__(
self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
self, list_state_client: ListStateClient, state_name: str, schema: StructType
) -> None:
self._list_state_client = list_state_client
self._state_name = state_name
Expand Down Expand Up @@ -274,8 +274,12 @@ def getValueState(
resets the expiration time to current processing time plus ttlDuration.
If ttl is not specified the state will never expire.
"""
self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms)
return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema)
schema_struct = self.stateful_processor_api_client.get_value_state(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It's not very straightforward to understand that we request initialization of the state and get the schema back as a result. (Arguably it's already confusing that we get nothing back from get_XXX, though that follows the current naming convention.)

Why not just have a separate method? It isn't very heavyweight even if we have separate request call(s) for the schema, and it's only needed for string schemas.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I'm open to other opinions; I know this saves one round trip, and some people prefer performance over cleaner code.
cc @anishshri-db It'd be good to hear your input.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's probably better to have a separate API for this. As @HeartSaVioR mentioned, could we do the conversion only if a string is passed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, will update, thanks!

state_name, schema, ttl_duration_ms
)
return ValueState(
ValueStateClient(self.stateful_processor_api_client), state_name, schema_struct
)

def getListState(
self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None
Expand All @@ -298,8 +302,12 @@ def getListState(
resets the expiration time to current processing time plus ttlDuration.
If ttl is not specified the state will never expire.
"""
self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
schema_struct = self.stateful_processor_api_client.get_list_state(
state_name, schema, ttl_duration_ms
)
return ListState(
ListStateClient(self.stateful_processor_api_client), state_name, schema_struct
)

def getMapState(
self,
Expand Down Expand Up @@ -329,11 +337,16 @@ def getMapState(
resets the expiration time to current processing time plus ttlDuration.
If ttl is not specified the state will never expire.
"""
self.stateful_processor_api_client.get_map_state(
(
user_key_schema_struct,
value_schema_struct,
) = self.stateful_processor_api_client.get_map_state(
state_name, user_key_schema, value_schema, ttl_duration_ms
)
return MapState(
MapStateClient(self.stateful_processor_api_client, user_key_schema, value_schema),
MapStateClient(
self.stateful_processor_api_client, user_key_schema_struct, value_schema_struct
),
state_name,
)

Expand Down
Loading
Loading