Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
it works
  • Loading branch information
ericm-db committed Jul 16, 2024
commit e0d8ffb7819ccb3c046f5572295421d356bc3ad1
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3816,6 +3816,12 @@
],
"sqlState" : "42K06"
},
"STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE" : {
"message" : [
"Cannot change <stateName> to <newType> between query restarts. Please set <stateName> to <oldType>, or restart with a new checkpoint directory."
],
"sqlState" : "42K06"
},
"STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
"message" : [
"The streaming query failed to validate written state for key row.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,18 +302,28 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi

// Factory object used to build the ColumnFamilySchema for each registered state variable.
private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1

// Factory object used to build the TransformWithStateVariableInfo (name, variable type,
// TTL flag) for each registered state variable.
private[sql] val stateVariableUtils = TransformWithStateVariableUtils

// Because this is only happening on the driver side, there is only
// one task modifying and accessing this map at a time
private[sql] val columnFamilySchemas: mutable.Map[String, ColumnFamilySchema] =
new mutable.HashMap[String, ColumnFamilySchema]()

// State variable metadata, keyed by state variable name. The get*State methods of this
// handle insert one entry per registered variable.
// NOTE(review): presumably the single-task assumption stated for columnFamilySchemas
// applies to this map as well - confirm.
private[sql] val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] =
new mutable.HashMap[String, TransformWithStateVariableInfo]()

// Immutable copy of the column family schemas collected so far.
def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap

// Immutable copy of the state variable metadata collected so far.
def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap

/**
 * Registers a value state variable on this driver-side handle.
 *
 * No usable state object is created here: the call only records the variable's
 * column family schema and its variable metadata under `stateName`, then returns
 * a null placeholder.
 *
 * @param stateName  name of the state variable; used as the key in both registries
 * @param valEncoder encoder for the stored value type
 * @return always `null`, cast to `ValueState[T]`
 */
override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
  // Handle-state check runs first, before anything is recorded.
  verifyStateVarOperations("get_value_state", PRE_INIT)

  // The trailing `false` is presumably the TTL flag of the schema helper (the TTL
  // overload passes `true`) - TODO confirm against ColumnFamilySchemaUtilsV1.
  columnFamilySchemas.put(
    stateName,
    columnFamilySchemaUtils.getValueStateSchema(stateName, keyExprEnc, valEncoder, false))

  stateVariableInfos.put(
    stateName,
    stateVariableUtils.getValueState(stateName, ttlEnabled = false))

  null.asInstanceOf[ValueState[T]]
}

Expand All @@ -325,6 +335,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = columnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getValueState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ValueState[T]]
}

Expand All @@ -333,6 +346,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, false)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getListState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

Expand All @@ -344,6 +360,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, true)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getListState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

Expand All @@ -355,6 +374,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getMapState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

Expand All @@ -367,6 +389,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getMapState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
Expand Down Expand Up @@ -81,7 +82,8 @@ case class TransformWithStateExec(
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: SparkPlan)
extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {
extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec
with Logging {

override def shortName: String = "transformWithStateExec"

Expand Down Expand Up @@ -123,6 +125,12 @@ case class TransformWithStateExec(
columnFamilySchemas
}

/**
 * Returns the state variable metadata, keyed by state variable name, that the
 * driver-side processor handle has collected.
 *
 * The processor handle is closed via `closeProcessorHandle()` once the metadata
 * has been read, so every call obtains a handle and closes it again.
 */
private def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
  val driverHandle = getDriverProcessorHandle()
  // The accessor returns an immutable copy, so it is still usable after the close below.
  val collectedInfos = driverHandle.getStateVariableInfos
  closeProcessorHandle()
  collectedInfos
}

/**
* This method is used for the driver-side stateful processor after we
* have collected all the necessary schemas.
Expand Down Expand Up @@ -450,19 +458,38 @@ case class TransformWithStateExec(
val newTimeMode = newJsonProps("timeMode").asInstanceOf[String]
val newOutputMode = newJsonProps("outputMode").asInstanceOf[String]
if (oldTimeMode != newTimeMode) {
throw new StateStoreInvalidConfigAfterRestart(
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"timeMode",
oldTimeMode,
newTimeMode
)
}
if (oldOutputMode != newOutputMode) {
throw new StateStoreInvalidConfigAfterRestart(
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"outputMode",
oldOutputMode,
newOutputMode
)
}
// compare state variable infos
val oldStateVariableInfos = oldJsonProps("stateVariables").
asInstanceOf[List[Map[String, Any]]]
.map(TransformWithStateVariableInfo.fromMap)
val newStateVariableInfos = getStateVariableInfos()
oldStateVariableInfos.foreach { oldInfo =>
val newInfo = newStateVariableInfos.get(oldInfo.stateName)
newInfo match {
case Some(stateVarInfo) =>
if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
throw StateStoreErrors.invalidVariableTypeChange(
stateVarInfo.stateName,
oldInfo.stateVariableType.toString,
stateVarInfo.stateVariableType.toString
)
}
case None =>
}
}
case (_, _) =>
}
}
Expand All @@ -479,7 +506,10 @@ case class TransformWithStateExec(

val operatorPropertiesJson: JValue =
("timeMode" -> JString(timeMode.toString)) ~
("outputMode" -> JString(outputMode.toString))
("outputMode" -> JString(outputMode.toString)) ~
("stateVariables" -> getStateVariableInfos().map { case (_, stateInfo) =>
stateInfo.jsonValue
}.arr)

val json = compact(render(operatorPropertiesJson))
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType

/**
 * Enumeration of the kinds of state variable: value, list and map state.
 *
 * The member names are part of the persisted format: they are written out with
 * `toString` and read back with `withName`, so they must not be renamed.
 */
object StateVariableType extends Enumeration {
  type StateVariableType = Value

  val ValueState: Value = Value
  val ListState: Value = Value
  val MapState: Value = Value
}

/**
 * Metadata describing a single state variable.
 *
 * @param stateName         name of the state variable
 * @param stateVariableType kind of the variable (value, list or map state)
 * @param ttlEnabled        whether the variable was registered with TTL enabled
 */
case class TransformWithStateVariableInfo(
    stateName: String,
    stateVariableType: StateVariableType,
    ttlEnabled: Boolean) {

  /**
   * JSON object with the fields `stateName`, `stateVariableType` (the enumeration
   * member's name) and `ttlEnabled`, in that order.
   */
  def jsonValue: JValue = {
    val nameField = "stateName" -> JString(stateName)
    val typeField = "stateVariableType" -> JString(stateVariableType.toString)
    val ttlField = "ttlEnabled" -> JBool(ttlEnabled)
    nameField ~ typeField ~ ttlField
  }

  /** Compact string rendering of [[jsonValue]]. */
  def json: String = {
    val rendered = render(jsonValue)
    compact(rendered)
  }
}

/** Builders that reconstruct a [[TransformWithStateVariableInfo]] from its serialized form. */
object TransformWithStateVariableInfo {

  /**
   * Parses a JSON object string and rebuilds the info from its fields.
   *
   * @param json JSON text containing `stateName`, `stateVariableType` and `ttlEnabled`
   */
  def fromJson(json: String): TransformWithStateVariableInfo = {
    implicit val formats: DefaultFormats.type = DefaultFormats
    fromMap(JsonMethods.parse(json).extract[Map[String, Any]])
  }

  /**
   * Rebuilds the info from an already-extracted field map.
   *
   * The three keys are looked up and cast directly, so a missing key or a value of
   * the wrong runtime type fails with an exception, and an unknown variable type
   * name is rejected by `StateVariableType.withName`.
   *
   * @param map field map with `stateName` (String), `stateVariableType` (String)
   *            and `ttlEnabled` (Boolean)
   */
  def fromMap(map: Map[String, Any]): TransformWithStateVariableInfo =
    TransformWithStateVariableInfo(
      stateName = map("stateName").asInstanceOf[String],
      stateVariableType =
        StateVariableType.withName(map("stateVariableType").asInstanceOf[String]),
      ttlEnabled = map("ttlEnabled").asInstanceOf[Boolean])
}
/**
 * Factory methods that build the [[TransformWithStateVariableInfo]] for each kind
 * of state variable. Each method only fixes the variable type; the name and TTL
 * flag are passed through unchanged.
 */
object TransformWithStateVariableUtils {

  /** Info for a value state variable named `stateName`. */
  def getValueState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo =
    TransformWithStateVariableInfo(
      stateName = stateName,
      stateVariableType = StateVariableType.ValueState,
      ttlEnabled = ttlEnabled)

  /** Info for a list state variable named `stateName`. */
  def getListState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo =
    TransformWithStateVariableInfo(
      stateName = stateName,
      stateVariableType = StateVariableType.ListState,
      ttlEnabled = ttlEnabled)

  /** Info for a map state variable named `stateName`. */
  def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo =
    TransformWithStateVariableInfo(
      stateName = stateName,
      stateVariableType = StateVariableType.MapState,
      ttlEnabled = ttlEnabled)
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ object StateStoreErrors {
StateStoreInvalidConfigAfterRestart = {
new StateStoreInvalidConfigAfterRestart(configName, oldConfig, newConfig)
}

/**
 * Builds the error reported when a state variable is declared with a different
 * variable type than before.
 *
 * @param stateName name of the affected state variable
 * @param oldType   the previous variable type
 * @param newType   the newly requested variable type
 */
def invalidVariableTypeChange(
    stateName: String,
    oldType: String,
    newType: String): StateStoreInvalidVariableTypeChange =
  new StateStoreInvalidVariableTypeChange(stateName, oldType, newType)
}

class StateStoreInvalidConfigAfterRestart(configName: String, oldConfig: String, newConfig: String)
Expand All @@ -190,6 +195,16 @@ class StateStoreInvalidConfigAfterRestart(configName: String, oldConfig: String,
)
)

/**
 * Error raised with the `STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE` error class when
 * a state variable's type differs from the previously recorded one.
 *
 * The three message parameters are passed under the names `stateName`, `oldType`
 * and `newType`.
 *
 * @param stateName name of the affected state variable
 * @param oldType   the previous variable type
 * @param newType   the newly requested variable type
 */
class StateStoreInvalidVariableTypeChange(stateName: String, oldType: String, newType: String)
extends SparkUnsupportedOperationException(
errorClass = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE",
messageParameters = Map(
"stateName" -> stateName,
"oldType" -> oldType,
"newType" -> newType
)
)

class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String)
extends SparkUnsupportedOperationException(
errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming

import java.io.File
import java.util.UUID

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkRuntimeException
Expand All @@ -28,7 +27,7 @@ import org.apache.spark.sql.{Dataset, Encoders, Row}
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreInvalidConfigAfterRestart, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, TestClass}
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaV3File, StateStoreInvalidConfigAfterRestart, StateStoreInvalidVariableTypeChange, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, StatefulProcessorCannotPerformOperationWithInvalidHandleState, TestClass}
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
Expand Down Expand Up @@ -64,6 +63,29 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
}
}

/**
 * Test-only processor used to check that switching a state variable between
 * value state and list state fails across query runs.
 *
 * It registers "countState" as a `ListState[Long]` (presumably the same name that
 * RunningCountStatefulProcessor uses for its value state - confirm) and never
 * emits any output.
 */
class RunningCountListStatefulProcessor
  extends StatefulProcessor[String, String, (String, String)]
  with Logging {

  @transient protected var _countState: ListState[Long] = _

  /** Registers the list state; neither mode argument is consulted. */
  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    val listHandle = getHandle.getListState[Long]("countState", Encoders.scalaLong)
    _countState = listHandle
  }

  /** Ignores all input rows and timers; always returns an empty iterator. */
  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] =
    Iterator.empty
}

class RunningCountStatefulProcessorInt extends StatefulProcessor[String, String, (String, String)]
with Logging {
@transient protected var _countState: ValueState[Int] = _
Expand Down Expand Up @@ -1175,7 +1197,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}

test("test that different timeMode, outputMode after query restart fails") {
test("test that different outputMode after query restart fails") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
Expand Down Expand Up @@ -1211,6 +1233,41 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}

// Checks that restarting a query from the same checkpoint with a processor that
// declares "countState" as a different kind of state variable fails with
// StateStoreInvalidVariableTypeChange.
test("test that changing between different state variable types fails") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
withTempDir { checkpointDir =>
val inputData = MemoryStream[String]
// First run: RunningCountStatefulProcessor, whose body is not shown here
// (presumably it keeps "countState" as a value state - confirm).
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor(),
TimeMode.None(),
OutputMode.Update())

// Process one batch and stop, leaving state and metadata in checkpointDir.
testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
AddData(inputData, "a"),
CheckNewAnswer(("a", "1")),
StopStream
)
// Second run: same input and checkpoint, but the processor now registers
// "countState" as a list state.
val result2 = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountListStatefulProcessor(),
TimeMode.None(),
OutputMode.Update())
// The restart must fail, and the error message must name the offending variable.
testStream(result2, OutputMode.Update())(
StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
AddData(inputData, "a"),
ExpectFailure[StateStoreInvalidVariableTypeChange] { t =>
assert(t.getMessage.contains("Cannot change countState"))
}
)
}
}
}

test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
Expand Down