Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working version, will write test suites and test for composite types
  • Loading branch information
jingz-db committed Jul 8, 2024
commit 2bbd2cece16bf78282d21ded2e294e3b012af910
Original file line number Diff line number Diff line change
Expand Up @@ -17,68 +17,70 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec}

/**
 * Builds the [[ColumnFamilySchema]] registered for each stateful variable
 * (value/list/map state) used with transformWithState. Implementations supply the
 * store-level key/value row schemas, the user-facing key/value schemas, and the
 * state encoder spec for the state store layer.
 */
trait ColumnFamilySchemaUtils {
  /**
   * Schema for a value state variable.
   *
   * @param stateName  unique name of the state variable
   * @param keyEncoder encoder for the grouping key
   * @param valEncoder encoder for the stored value
   * @param hasTtl     whether a TTL column is appended to the value schema
   */
  def getValueStateSchema[T](
      stateName: String,
      keyEncoder: ExpressionEncoder[Any],
      valEncoder: Encoder[T],
      hasTtl: Boolean): ColumnFamilySchema

  /** Schema for a list state variable; parameters as in [[getValueStateSchema]]. */
  def getListStateSchema[T](
      stateName: String,
      keyEncoder: ExpressionEncoder[Any],
      valEncoder: Encoder[T],
      hasTtl: Boolean): ColumnFamilySchema

  /**
   * Schema for a map state variable, keyed by the composite of the grouping key and
   * the user-provided map key.
   *
   * @param userKeyEnc encoder for the user-specified map key
   */
  def getMapStateSchema[K, V](
      stateName: String,
      keyEncoder: ExpressionEncoder[Any],
      userKeyEnc: Encoder[K],
      valEncoder: Encoder[V],
      hasTtl: Boolean): ColumnFamilySchema
}

object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {

def getValueStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
keyEncoder.schema,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}

def getListStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
keyEncoder.schema,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}

def getMapStateSchema[K, V](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema)
new ColumnFamilySchemaV1(
stateName,
COMPOSITE_KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
compositeKeySchema,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
Some(userKeyEnc.schema))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
import org.apache.spark.sql.types.{BinaryType, LongType, StructType}

object TransformWithStateKeyValueRowSchema {
/**
* The following are the key/value row schema used in StateStore layer.
* Key/value rows will be serialized into Binary format in `StateTypesEncoder`.
* The "real" key/value row schema will be written into state schema metadata.
*/
val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType)
val COMPOSITE_KEY_ROW_SCHEMA: StructType = new StructType()
.add("key", BinaryType)
Expand All @@ -35,6 +40,29 @@ object TransformWithStateKeyValueRowSchema {
val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType()
.add("value", BinaryType)
.add("ttlExpirationMs", LongType)

/**
* Helper function for passing the key/value schema to write to state schema metadata.
* Return value schema with additional TTL column if TTL is enabled.
*
* @param schema Value Schema returned by value encoder that user passed in
* @param hasTTL TTL enabled or not
* @return a schema with additional TTL column if TTL is enabled.
*/
def getValueSchemaWithTTL(schema: StructType, hasTTL: Boolean): StructType = {
if (hasTTL) {
new StructType(schema.fields).add("ttlExpirationMs", LongType)
} else schema
}

/**
* Given grouping key and user key schema, return the schema of the composite key.
*/
def getCompositeKeySchema(
groupingKeySchema: StructType,
userKeySchema: StructType): StructType = {
new StructType(groupingKeySchema.fields ++ userKeySchema.fields)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getValueStateSchema(stateName, false)
getValueStateSchema(stateName, keyExprEnc, valEncoder, false)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[ValueState[T]]
}
Expand All @@ -344,7 +344,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
ttlConfig: TTLConfig): ValueState[T] = {
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getValueStateSchema(stateName, true)
getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[ValueState[T]]
}
Expand All @@ -362,7 +362,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, false)
getListStateSchema(stateName, keyExprEnc, valEncoder, false)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[ListState[T]]
}
Expand All @@ -384,7 +384,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
ttlConfig: TTLConfig): ListState[T] = {
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, true)
getListStateSchema(stateName, keyExprEnc, valEncoder, true)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[ListState[T]]
}
Expand All @@ -406,7 +406,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
valEncoder: Encoder[V]): MapState[K, V] = {
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, userKeyEnc, false)
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[MapState[K, V]]
}
Expand All @@ -430,7 +430,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
ttlConfig: TTLConfig): MapState[K, V] = {
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, userKeyEnc, true)
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
columnFamilySchemas.put(stateName, colFamilySchema)
null.asInstanceOf[MapState[K, V]]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimeMode}

abstract class StatefulProcessorHandleImplBase(
timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any])
extends StatefulProcessorHandle {
timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) extends StatefulProcessorHandle {

protected var currState: StatefulProcessorHandleState = PRE_INIT

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogChec
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types._

object TransformWithStateSuiteUtils {
val NUM_SHUFFLE_PARTITIONS = 5
Expand Down Expand Up @@ -916,6 +917,13 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
val keyExprEncoderSer = encoderFor(Encoders.scalaInt).schema
println("keyExprEncoder here: " + JString(keyExprEncoderSer.json))
println("valueEncoder here: " + JString(Encoders.STRING.schema.json))
println("composite schema: " +
new StructType().add("key", BinaryType)
.add("userKey", BinaryType))
val keySchema = new StructType().add("key", BinaryType)
val userkeySchema = new StructType().add("userkeySchema", BinaryType)
println("composite schema copy: " +
StructType(keySchema.fields ++ userkeySchema.fields))
}
}
}