Skip to content
Prev Previous commit
Next Next commit
validation
  • Loading branch information
ericm-db committed Jul 16, 2024
commit 96889053e0122ce78f6d9514efa46d23bafb046b
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
Expand Down Expand Up @@ -439,56 +438,62 @@ case class TransformWithStateExec(
}
}

private def checkOperatorPropEquality[T](
fieldName: String,
oldMetadataV2: OperatorStateMetadataV2,
newMetadataV2: OperatorStateMetadataV2): Unit = {
val oldJsonString = oldMetadataV2.operatorPropertiesJson
val newJsonString = newMetadataV2.operatorPropertiesJson
// verify that timeMode, outputMode are the same
implicit val formats: DefaultFormats.type = DefaultFormats
val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]]
val newJsonProps = JsonMethods.parse(newJsonString).extract[Map[String, Any]]
val oldProp = oldJsonProps(fieldName).asInstanceOf[T]
val newProp = newJsonProps(fieldName).asInstanceOf[T]
if (oldProp != newProp) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
fieldName,
oldProp.toString,
newProp.toString
)
}
}

private def checkStateVariableEquality(oldMetadataV2: OperatorStateMetadataV2): Unit = {
val oldJsonString = oldMetadataV2.operatorPropertiesJson
implicit val formats: DefaultFormats.type = DefaultFormats
val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]]
// compare state variable infos
val oldStateVariableInfos = oldJsonProps("stateVariables").
asInstanceOf[List[Map[String, Any]]]
.map(TransformWithStateVariableInfo.fromMap)
val newStateVariableInfos = getStateVariableInfos()
oldStateVariableInfos.foreach { oldInfo =>
val newInfo = newStateVariableInfos.get(oldInfo.stateName)
newInfo match {
case Some(stateVarInfo) =>
if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
throw StateStoreErrors.invalidVariableTypeChange(
stateVarInfo.stateName,
oldInfo.stateVariableType.toString,
stateVarInfo.stateVariableType.toString
)
}
case None =>
}
}
}

def validateMetadatas(
oldMetadata: OperatorStateMetadata,
newMetadata: OperatorStateMetadata): Unit = {
// if both metadatas are instance of OperatorStateMetadatV2
(oldMetadata, newMetadata) match {
case (oldMetadataV2: OperatorStateMetadataV2,
case (
oldMetadataV2: OperatorStateMetadataV2,
newMetadataV2: OperatorStateMetadataV2) =>
val oldJsonString = oldMetadataV2.operatorPropertiesJson
val newJsonString = newMetadataV2.operatorPropertiesJson
// verify that timeMode, outputMode are the same
implicit val formats: DefaultFormats.type = DefaultFormats
val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]]
val newJsonProps = JsonMethods.parse(newJsonString).extract[Map[String, Any]]
val oldTimeMode = oldJsonProps("timeMode").asInstanceOf[String]
val oldOutputMode = oldJsonProps("outputMode").asInstanceOf[String]
val newTimeMode = newJsonProps("timeMode").asInstanceOf[String]
val newOutputMode = newJsonProps("outputMode").asInstanceOf[String]
if (oldTimeMode != newTimeMode) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"timeMode",
oldTimeMode,
newTimeMode
)
}
if (oldOutputMode != newOutputMode) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"outputMode",
oldOutputMode,
newOutputMode
)
}
// compare state variable infos
val oldStateVariableInfos = oldJsonProps("stateVariables").
asInstanceOf[List[Map[String, Any]]]
.map(TransformWithStateVariableInfo.fromMap)
val newStateVariableInfos = getStateVariableInfos()
oldStateVariableInfos.foreach { oldInfo =>
val newInfo = newStateVariableInfos.get(oldInfo.stateName)
newInfo match {
case Some(stateVarInfo) =>
if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
throw StateStoreErrors.invalidVariableTypeChange(
stateVarInfo.stateName,
oldInfo.stateVariableType.toString,
stateVarInfo.stateVariableType.toString
)
}
case None =>
}
}
checkOperatorPropEquality[String]("timeMode", oldMetadataV2, newMetadataV2)
checkOperatorPropEquality[String]("outputMode", oldMetadataV2, newMetadataV2)
checkStateVariableEquality(oldMetadataV2)
case (_, _) =>
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}

test("transformWithState - verify that OperatorStateMetadataV2" +
" file is being written correctly") {
" integrates with state-metadata source") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
Expand Down Expand Up @@ -1002,17 +1002,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest
Row(0, "transformWithStateExec", "default", 5, 0L, 0L),
Row(0, "transformWithStateExec", "default", 5, 1L, 1L)
))
// need line to be unbroken, otherwise the test will fail.
// scalastyle:off
val expectedAnswer = """{"timeMode":"NoTime","outputMode":"Update","stateVariables":[{"stateName":"countState","stateVariableType":"ValueState","ttlEnabled":false}]}"""
// scalastyle:on
checkAnswer(df.select(df.metadataColumn("_operatorProperties")),
Seq(
Row("""{"timeMode":"NoTime","outputMode":"Update"}"""),
Row("""{"timeMode":"NoTime","outputMode":"Update"}""")
Row(expectedAnswer),
Row(expectedAnswer)
)
)
}
}
}


test("transformWithState - verify that metadata logs are purged") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
Expand Down