Add a suite with composite types; investigate why the key encoder spec is overwritten
jingz-db committed Jul 8, 2024
commit 4f5185a1c1752b3a2bedc828419b8f432912f47a
@@ -53,7 +53,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
keyEncoder.schema,
getKeySchema(keyEncoder.schema),
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}
@@ -65,7 +65,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
keyEncoder.schema,
getKeySchema(keyEncoder.schema),
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}
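Note: with this change the key schema written to the state schema file is nested under a "key" column via getKeySchema, while the encoder spec below it still receives KEY_ROW_SCHEMA, which may be what the commit title's "key encoder spec overwritten" note refers to. A minimal sketch of the resulting shape (illustrative only, assuming a String grouping key; runnable in spark-shell):

  import org.apache.spark.sql.Encoders
  import org.apache.spark.sql.types.StructType

  // Schema reported by the key encoder for a String grouping key.
  val keyEncoderSchema: StructType = Encoders.STRING.schema
  println(keyEncoderSchema.simpleString)                 // struct<value:string>

  // What getKeySchema now records in the state schema metadata.
  val storedKeySchema = new StructType().add("key", keyEncoderSchema)
  println(storedKeySchema.simpleString)                  // struct<key:struct<value:string>>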
@@ -41,18 +41,24 @@ object TransformWithStateKeyValueRowSchema {
.add("value", BinaryType)
.add("ttlExpirationMs", LongType)

/** Helper functions for passing the key/value schema to write to state schema metadata. */

/**
* Return the key schema nested under a "key" column.
*/
def getKeySchema(schema: StructType): StructType = {
new StructType().add("key", schema)
}

/**
* Helper function for passing the key/value schema to write to state schema metadata.
* Return value schema with additional TTL column if TTL is enabled.
*
* @param schema value schema returned by the value encoder that the user passed in
* @param hasTTL whether TTL is enabled
* @return a schema with additional TTL column if TTL is enabled.
*/
def getValueSchemaWithTTL(schema: StructType, hasTTL: Boolean): StructType = {
if (hasTTL) {
val valSchema = if (hasTTL) {
new StructType(schema.fields).add("ttlExpirationMs", LongType)
} else schema
new StructType()
.add("value", valSchema)
}

/**
@@ -61,7 +67,9 @@
def getCompositeKeySchema(
groupingKeySchema: StructType,
userKeySchema: StructType): StructType = {
new StructType(groupingKeySchema.fields ++ userKeySchema.fields)
new StructType()
.add("key", new StructType(groupingKeySchema.fields))
.add("userKey", new StructType(userKeySchema.fields))
}
}
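A quick sketch of the shapes these helpers produce (illustrative only, assuming String encoders for the grouping key, user key, and value; runnable in spark-shell):

  import org.apache.spark.sql.Encoders
  import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._

  // TTL-enabled value schema: the encoder schema gains a ttlExpirationMs column
  // and is then nested under a "value" column.
  val valueWithTtl = getValueSchemaWithTTL(Encoders.STRING.schema, hasTTL = true)
  println(valueWithTtl.simpleString)
  // struct<value:struct<value:string,ttlExpirationMs:bigint>>

  // Composite key for map state: the grouping key and user key become two nested
  // columns rather than a flat concatenation of their fields.
  val compositeKey = getCompositeKeySchema(Encoders.STRING.schema, Encoders.STRING.schema)
  println(compositeKey.simpleString)
  // struct<key:struct<value:string>,userKey:struct<value:string>>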

@@ -430,7 +430,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
ttlConfig: TTLConfig): MapState[K, V] = {
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[MapState[K, V]]
}
@@ -298,8 +298,8 @@ object KeyStateEncoderSpec {
asInstanceOf[List[_]].map(_.asInstanceOf[Int])
RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals)
case "PrefixKeyScanStateEncoderSpec" =>
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int]
PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt]
PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey.toInt)
}
}
}
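The cast changes here because json4s materializes JSON integer literals as JInt wrapping a BigInt, so reading the deserialized map value directly as Int fails at runtime. A small sketch of the behavior (illustrative only, not the actual deserialization path):

  import org.json4s.jackson.JsonMethods.parse

  // json4s represents JSON integers as BigInt, so the extracted value is a BigInt, not an Int.
  val m = parse("""{"numColsPrefixKey": 1}""").values.asInstanceOf[Map[String, Any]]
  // m("numColsPrefixKey").asInstanceOf[Int]                      // would throw ClassCastException at runtime
  val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt].toInt   // 1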
@@ -20,16 +20,15 @@ package org.apache.spark.sql.streaming
import java.io.File
import java.util.UUID

import org.json4s.JsonAST.JString
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, Encoders}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException}
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass}
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -310,6 +309,21 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
}
}

class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor {
@transient private var _listState: ListState[TestClass] = _
@transient private var _mapState: MapState[String, POJOTestClass] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
_listState = getHandle.getListState[TestClass](
"listState", Encoders.product[TestClass])
_mapState = getHandle.getMapState[String, POJOTestClass](
"mapState", Encoders.STRING, Encoders.bean(classOf[POJOTestClass]))
}
}
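For context, TestClass and POJOTestClass are imported above from org.apache.spark.sql.execution.streaming.state; their definitions are not part of this diff. Hypothetical sketches consistent with the value schemas asserted in the suite below:

  // Hypothetical shapes only; see the real definitions in the state store test utilities.
  case class TestClass(id: Long, name: String)       // Encoders.product -> struct<id:bigint,name:string>

  class POJOTestClass {                               // Encoders.bean    -> struct<id:int,name:string>
    private var id: Int = 0
    private var name: String = ""
    def getId: Int = id
    def setId(value: Int): Unit = { id = value }
    def getName: String = name
    def setName(value: String): Unit = { name = value }
  }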

/**
* Class that adds tests for transformWithState stateful streaming operator
*/
@@ -906,24 +920,86 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {

class TransformWithStateSchemaSuite extends StateStoreMetricsTest {

test("schema") {
import testImplicits._

test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
StateTypesEncoder(keySerializer = encoderFor(Encoders.scalaInt).createSerializer(),
valEncoder = Encoders.STRING, stateName = "someState", hasTtl = false)

val keyExprEncoderSer = encoderFor(Encoders.scalaInt).schema
println("keyExprEncoder here: " + JString(keyExprEncoderSer.json))
println("valueEncoder here: " + JString(Encoders.STRING.schema.json))
println("composite schema: " +
new StructType().add("key", BinaryType)
.add("userKey", BinaryType))
val keySchema = new StructType().add("key", BinaryType)
val userkeySchema = new StructType().add("userkeySchema", BinaryType)
println("composite schema copy: " +
StructType(keySchema.fields ++ userkeySchema.fields))
withTempDir { checkpointDir =>
val metadataPathPostfix = "state/0/default/_metadata"
val stateSchemaPath = new Path(checkpointDir.toString,
s"$metadataPathPostfix/schema/0")
val hadoopConf = spark.sessionState.newHadoopConf()
val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)

val schema0 = ColumnFamilySchemaV1(
"countState",
new StructType().add("key",
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType().add("value", LongType)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
)
val schema1 = ColumnFamilySchemaV1(
"listState",
new StructType().add("key",
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("id", LongType)
.add("name", StringType)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
)
val schema2 = ColumnFamilySchemaV1(
"mapState",
new StructType()
.add("key", new StructType().add("value", StringType))
.add("userKey", new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("id", IntegerType)
.add("name", StringType)),
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
Option(new StructType().add("value", StringType))
)
println("print out schema0: " + schema0)

val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new StatefulProcessorWithCompositeTypes(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
AddData(inputData, "a", "b"),
CheckNewAnswer(("a", "1"), ("b", "1")),
Execute { q =>
val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString,
metadataPathPostfix).toString)
val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath))

assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt)
assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt)

assert(colFamilySeq.length == 3)
assert(colFamilySeq.toSet == Set(
schema0, schema1, schema2
))
},
StopStream
)
}
}
}
}