Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
b1175e4
[WIP] Poc.
sahnib Jan 22, 2024
0a98ed8
Introduce Protobuf.
sahnib Feb 6, 2024
8e2b193
Fixing things.
sahnib Feb 6, 2024
16e4c17
support timeMode for python state v2 API
bogao007 Jun 20, 2024
92ef716
Add protobuf for serde
sahnib Feb 13, 2024
c3eaf38
protobuf change
bogao007 Jun 20, 2024
609d94e
Initial commit
bogao007 Jun 27, 2024
a27f9d9
better error handling, support value state with different types
bogao007 Jun 28, 2024
684939b
addressed comments
bogao007 Jul 3, 2024
7f65fbd
fix
bogao007 Jul 3, 2024
c25d7da
Added support for unix domain socket
bogao007 Jul 11, 2024
9c8c616
removed unrelated log lines, addressed part of the comments
bogao007 Jul 17, 2024
c641192
fix
bogao007 Jul 17, 2024
8d3da4e
Addressed comments
bogao007 Jul 19, 2024
cc9bf95
removed unnecessary print
bogao007 Jul 19, 2024
f7df2dc
rename
bogao007 Jul 19, 2024
27cd169
fix
bogao007 Jul 19, 2024
3b5b3e5
removed duplicate proto file
bogao007 Jul 20, 2024
5d910d8
revert unrelated changes
bogao007 Jul 20, 2024
df859ab
fix
bogao007 Jul 20, 2024
654f2f6
Added unit tests for transformWithStateInPandas
bogao007 Jul 24, 2024
38832a6
Merge branch 'master' into state-v2-initial
bogao007 Jul 24, 2024
0585ac0
fix and rename
bogao007 Jul 24, 2024
0ee5029
update test
bogao007 Jul 24, 2024
6232c81
Added lisences
bogao007 Jul 25, 2024
41f8234
fixed format issues
bogao007 Jul 25, 2024
d57633f
fix
bogao007 Jul 25, 2024
df9ea9e
fix format
bogao007 Jul 25, 2024
68f7a7e
doc
bogao007 Jul 25, 2024
ca5216b
addressed comments
bogao007 Jul 26, 2024
c9e3a7c
structured log
bogao007 Jul 26, 2024
2320805
suppress auto generated proto file
bogao007 Jul 29, 2024
6e5de2e
fix linter
bogao007 Jul 29, 2024
200ec5e
fixed dependency issue
bogao007 Jul 29, 2024
dd3e46b
make protobuf as local dependency
bogao007 Jul 30, 2024
e8360d4
fix dependency issue
bogao007 Jul 30, 2024
82983af
fix
bogao007 Jul 30, 2024
49dbc16
fix lint
bogao007 Jul 30, 2024
d4e04ea
fix
bogao007 Jul 30, 2024
e108f60
updated fix
bogao007 Jul 30, 2024
bae26c2
reformat
bogao007 Jul 30, 2024
d96fa9e
addressed comments
bogao007 Jul 31, 2024
92531db
fix linter
bogao007 Jul 31, 2024
d507793
linter
bogao007 Jul 31, 2024
5dcb4c8
addressed comments
bogao007 Aug 2, 2024
37be02a
address comment
bogao007 Aug 2, 2024
f63687f
addressed comments
bogao007 Aug 9, 2024
263c087
Merge branch 'master' into state-v2-initial
bogao007 Aug 10, 2024
c7b0a4f
address comments
bogao007 Aug 12, 2024
c80b292
address comments
bogao007 Aug 12, 2024
81276f3
address comments
bogao007 Aug 14, 2024
5886b5c
fix lint
bogao007 Aug 14, 2024
23e54b4
fix lint
bogao007 Aug 14, 2024
2ba4fd0
address comments
bogao007 Aug 14, 2024
2a9c20b
fix test
bogao007 Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
better error handling, support value state with different types
  • Loading branch information
bogao007 committed Jun 28, 2024
commit a27f9d9a98155d1d423ccbf4043a80dfa8541fcc
6 changes: 3 additions & 3 deletions python/pyspark/sql/pandas/group_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,12 @@ def transformWithStateUDF(state_api_client: StateApiClient, key: Any,
print("initializing stateful processor")
stateful_processor.init(handle)
print("setting handle state to initialized")
state_api_client.setHandleState(StatefulProcessorHandleState.INITIALIZED)
state_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

print(f"handling input rows for key: {key[0]}")
state_api_client.setImplicitKey(str(key[0]))
state_api_client.set_implicit_key(str(key[0]))
result = stateful_processor.handleInputRows(key, inputRows)
state_api_client.removeImplicitKey()
state_api_client.remove_implicit_key()

return result

Expand Down
31 changes: 8 additions & 23 deletions python/pyspark/sql/streaming/state_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
# place holder, will remove when actual implementation is done
# self.setHandleState(StatefulProcessorHandleState.CLOSED)

def setHandleState(self, state: StatefulProcessorHandleState) -> None:
def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add like 1-2 line comments for these functions ?

print(f"setting handle state to: {state}")
proto_state = self._get_proto_state(state)
set_handle_state = stateMessage.SetHandleState(state=proto_state)
Expand All @@ -60,7 +60,7 @@ def setHandleState(self, state: StatefulProcessorHandleState) -> None:
self.handle_state = state
print(f"setHandleState status= {status}")

def getValueState(self, state_name: str, schema: Union[StructType, str]) -> None:
def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None:
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))

Expand All @@ -77,7 +77,7 @@ def getValueState(self, state_name: str, schema: Union[StructType, str]) -> None
status = read_int(self.sockfile)
print(f"getValueState status= {status}")

def valueStateExists(self, state_name: str) -> bool:
def value_state_exists(self, state_name: str) -> bool:
print(f"checking value state exists: {state_name}")
exists_call = stateMessage.Exists(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(exists=exists_call)
Expand All @@ -92,7 +92,7 @@ def valueStateExists(self, state_name: str) -> bool:
else:
return False

def valueStateGet(self, state_name: str) -> Any:
def value_state_get(self, state_name: str) -> Any:
print(f"getting value state: {state_name}")
get_call = stateMessage.Get(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(get=get_call)
Expand All @@ -107,7 +107,7 @@ def valueStateGet(self, state_name: str) -> Any:
else:
return None

def valueStateUpdate(self, state_name: str, schema: Union[StructType, str], value: str) -> None:
def value_state_update(self, state_name: str, schema: Union[StructType, str], value: str) -> None:
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
print(f"updating value state: {state_name}")
Expand All @@ -121,7 +121,7 @@ def valueStateUpdate(self, state_name: str, schema: Union[StructType, str], valu
status = read_int(self.sockfile)
print(f"valueStateUpdate status= {status}")

def valueStateClear(self, state_name: str) -> None:
def value_state_clear(self, state_name: str) -> None:
print(f"clearing value state: {state_name}")
clear_call = stateMessage.Clear(stateName=state_name)
value_state_call = stateMessage.ValueStateCall(clear=clear_call)
Expand All @@ -132,22 +132,7 @@ def valueStateClear(self, state_name: str) -> None:
status = read_int(self.sockfile)
print(f"valueStateClear status= {status}")

def getListState(self, state_name: str, schema: Union[StructType, str]) -> None:
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))

state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = state_name
state_call_command.schema = schema.json()
call = stateMessage.StatefulProcessorCall(getListState=state_call_command)

message = stateMessage.StateRequest(statefulProcessorCall=call)

self._send_proto_message(message)
status = read_int(self.sockfile)
print(f"status= {status}")

def setImplicitKey(self, key: str) -> None:
def set_implicit_key(self, key: str) -> None:
print(f"setting implicit key: {key}")
set_implicit_key = stateMessage.SetImplicitKey(key=key)
request = stateMessage.ImplicitGroupingKeyRequest(setImplicitKey=set_implicit_key)
Expand All @@ -157,7 +142,7 @@ def setImplicitKey(self, key: str) -> None:
status = read_int(self.sockfile)
print(f"setImplicitKey status= {status}")

def removeImplicitKey(self) -> None:
def remove_implicit_key(self) -> None:
print(f"removing implicit key")
remove_implicit_key = stateMessage.RemoveImplicitKey()
request = stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key)
Expand Down
10 changes: 5 additions & 5 deletions python/pyspark/sql/streaming/stateful_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def __init__(self,
self.schema = schema

def exists(self) -> bool:
return self._state_api_client.valueStateExists(self._state_name)
return self._state_api_client.value_state_exists(self._state_name)

def get(self) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, we expect Row as state value, not a pandas DataFrame. Please let me know if you are proposing pandas DataFrame for better suit for more state types.

return self._state_api_client.valueStateGet(self._state_name)
return self._state_api_client.value_state_get(self._state_name)

def update(self, new_value: Any) -> None:
self._state_api_client.valueStateUpdate(self._state_name, self.schema, new_value)
self._state_api_client.value_state_update(self._state_name, self.schema, new_value)

def clear(self) -> None:
self._state_api_client.valueStateClear(self._state_name)
self._state_api_client.value_state_clear(self._state_name)


class ListState:
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self.state_api_client = state_api_client

def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState:
self.state_api_client.getValueState(state_name, schema)
self.state_api_client.get_value_state(state_name, schema)
return ValueState(self.state_api_client, state_name, schema)

def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListState:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package org.apache.spark.sql.execution.python

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.net.ServerSocket

import scala.collection.mutable
import scala.util.control.Breaks.break

import com.google.protobuf.ByteString

Expand All @@ -30,7 +29,7 @@ import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState}
import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateVariableRequest, ValueStateCall}
import org.apache.spark.sql.streaming.ValueState
import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, FloatType, IntegerType, LongType, StructType}

/**
* This class is used to handle the state requests from the Python side.
Expand All @@ -45,7 +44,7 @@ class TransformWithStateInPandasStateServer(
private var inputStream: DataInputStream = _
private var outputStream: DataOutputStream = _

private val valueStates = mutable.HashMap[String, ValueState[String]]()
private val valueStates = mutable.HashMap[String, ValueState[Any]]()

def run(): Unit = {
logWarning(s"Waiting for connection from Python worker")
Expand All @@ -61,29 +60,35 @@ class TransformWithStateInPandasStateServer(
while (listeningSocket.isConnected &&
statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) {

logWarning(s"reading the version")
val version = inputStream.readInt()
try {
logWarning(s"reading the version")
val version = inputStream.readInt()

if (version != -1) {
logWarning(s"version = ${version}")
assert(version == 0)
val messageLen = inputStream.readInt()
logWarning(s"parsing a message of ${messageLen} bytes")
if (version != -1) {
logWarning(s"version = ${version}")
assert(version == 0)
val messageLen = inputStream.readInt()
logWarning(s"parsing a message of ${messageLen} bytes")

val messageBytes = new Array[Byte](messageLen)
inputStream.read(messageBytes)
logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", ")")}")
val messageBytes = new Array[Byte](messageLen)
inputStream.read(messageBytes)
logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", ")")}")

val message = StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
val message = StateRequest.parseFrom(ByteString.copyFrom(messageBytes))

logWarning(s"read message = $message")
handleRequest(message)
logWarning(s"flush output stream")
logWarning(s"read message = $message")
handleRequest(message)
logWarning(s"flush output stream")

outputStream.flush()
outputStream.flush()
}
} catch {
case _: EOFException =>
logWarning(s"No more data to read from the socket")
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
return
}
}

logWarning(s"done from the state server thread")
}

Expand Down Expand Up @@ -121,7 +126,7 @@ class TransformWithStateInPandasStateServer(
if (message.getStateVariableRequest.getValueStateCall.getMethodCase ==
ValueStateCall.MethodCase.EXISTS) {
val stateName = message.getStateVariableRequest.getValueStateCall.getExists.getStateName
if (valueStates.contains(stateName)) {
if (valueStates.contains(stateName) && valueStates(stateName).exists()) {
logWarning(s"state $stateName exists")
outputStream.writeInt(0)
} else {
Expand All @@ -131,40 +136,53 @@ class TransformWithStateInPandasStateServer(
} else if (message.getStateVariableRequest.getValueStateCall.getMethodCase ==
ValueStateCall.MethodCase.GET) {
val stateName = message.getStateVariableRequest.getValueStateCall.getGet.getStateName
val valueState = valueStates.get(stateName).get
val valueOption = valueState.getOption()
if (valueOption.isDefined) {
outputStream.writeInt(0)
val value = valueOption.get
logWarning("got state value " + value)
val valueBytes = value.getBytes("UTF-8")
val byteLength = valueBytes.length
logWarning(s"writing value bytes of length $byteLength")
outputStream.writeInt(byteLength)
logWarning(s"writing value bytes: ${valueBytes.mkString("Array(", ", ", ")")}")
outputStream.write(valueBytes)
if (valueStates.contains(stateName)) {
val valueState = valueStates(stateName)
val valueOption = valueState.getOption()
if (valueOption.isDefined) {
outputStream.writeInt(0)
val value = valueOption.get.toString
logWarning("got state value " + value)
val valueBytes = value.getBytes("UTF-8")
val byteLength = valueBytes.length
logWarning(s"writing value bytes of length $byteLength")
outputStream.writeInt(byteLength)
logWarning(s"writing value bytes: ${valueBytes.mkString("Array(", ", ", ")")}")
outputStream.write(valueBytes)
} else {
logWarning(s"state $stateName doesn't exist")
outputStream.writeInt(1)
}
} else {
logWarning("didn't get state value")
logWarning(s"state $stateName doesn't exist")
outputStream.writeInt(1)
}
} else if (message.getStateVariableRequest.getValueStateCall.getMethodCase ==
ValueStateCall.MethodCase.UPDATE) {
val stateName = message.getStateVariableRequest.getValueStateCall.getUpdate.getStateName
val schema = message.getStateVariableRequest.getValueStateCall.getUpdate.getSchema
val value = message.getStateVariableRequest.getValueStateCall.getUpdate.getValue
val updateValueString = value.toStringUtf8
val structType = StructType.fromString(schema)
val field = structType.fields(0)
val updatedValue = castToType(updateValueString, field.dataType)
val updateRequest = message.getStateVariableRequest.getValueStateCall.getUpdate
val stateName = updateRequest.getStateName
val updateValueString = updateRequest.getValue.toStringUtf8
val dataType = StructType.fromString(updateRequest.getSchema).fields(0).dataType
val updatedValue = castToType(updateValueString, dataType)
logWarning(s"updating state $stateName with value $updatedValue and" +
s" type ${updatedValue.getClass}")
valueStates(stateName).update(updateValueString)
outputStream.writeInt(0)
if (valueStates.contains(stateName)) {
valueStates(stateName).update(updatedValue)
outputStream.writeInt(0)
} else {
logWarning(s"state $stateName doesn't exist")
outputStream.writeInt(1)
}
} else if (message.getStateVariableRequest.getValueStateCall.getMethodCase ==
ValueStateCall.MethodCase.CLEAR) {
val stateName = message.getStateVariableRequest.getValueStateCall.getClear.getStateName
valueStates(stateName).clear()
outputStream.writeInt(0)
if (valueStates.contains(stateName)) {
valueStates(stateName).clear()
outputStream.writeInt(0)
} else {
logWarning(s"state $stateName doesn't exist")
outputStream.writeInt(1)
}
} else {
throw new IllegalArgumentException("Invalid method call")
}
Expand Down Expand Up @@ -199,8 +217,8 @@ class TransformWithStateInPandasStateServer(
val structType = StructType.fromString(schema)
val field = structType.fields(0)
val encoder = getEncoder(field.dataType)
val state = statefulProcessorHandle.getValueState[String](stateName, Encoders.STRING)
// val state = statefulProcessorHandle.getValueState(stateName, encoder)
val state = statefulProcessorHandle.getValueState(stateName, encoder)
.asInstanceOf[ValueState[Any]]
valueStates.put(stateName, state)
outputStream.writeInt(0)
} else {
Expand Down Expand Up @@ -229,7 +247,6 @@ class TransformWithStateInPandasStateServer(
case DoubleType => Encoders.DOUBLE
case FloatType => Encoders.FLOAT
case BooleanType => Encoders.BOOLEAN
case StringType => Encoders.STRING
case _ => Encoders.STRING
}
}
Expand Down