feedback
ericm-db committed Jul 9, 2024
commit 3691a16d051b9d36813278c2cecce1be9aa3e08e
@@ -51,7 +51,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {
       keyEncoder: ExpressionEncoder[Any],
       valEncoder: Encoder[T],
       hasTtl: Boolean): ColumnFamilySchemaV1 = {
-    new ColumnFamilySchemaV1(
+    ColumnFamilySchemaV1(
       stateName,
       getKeySchema(keyEncoder.schema),
       getValueSchemaWithTTL(valEncoder.schema, hasTtl),
@@ -63,7 +63,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {
       keyEncoder: ExpressionEncoder[Any],
       valEncoder: Encoder[T],
       hasTtl: Boolean): ColumnFamilySchemaV1 = {
-    new ColumnFamilySchemaV1(
+    ColumnFamilySchemaV1(
       stateName,
       getKeySchema(keyEncoder.schema),
       getValueSchemaWithTTL(valEncoder.schema, hasTtl),
@@ -77,7 +77,7 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {
       valEncoder: Encoder[V],
       hasTtl: Boolean): ColumnFamilySchemaV1 = {
     val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema)
-    new ColumnFamilySchemaV1(
+    ColumnFamilySchemaV1(
       stateName,
       compositeKeySchema,
       getValueSchemaWithTTL(valEncoder.schema, hasTtl),
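This change works because ColumnFamilySchemaV1 is a case class, so its compiler-generated companion apply makes the explicit new keyword redundant. A minimal sketch of the pattern, with simplified illustrative fields rather than the real definition:

// Case classes get a compiler-generated companion apply, so `new` is
// optional and idiomatically omitted. Fields are simplified for illustration.
case class ColumnFamilySchemaV1(columnFamilyName: String, keyJson: String, valueJson: String)

object ApplyDemo extends App {
  val viaNew = new ColumnFamilySchemaV1("countState", "{}", "{}")
  val viaApply = ColumnFamilySchemaV1("countState", "{}", "{}") // same result, no `new`
  assert(viaNew == viaApply) // case classes compare structurally
}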
@@ -85,6 +85,10 @@ class IncrementalExecution(
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
 
+  /**
+   * This value dictates which schema format version the state schema should be written in
+   * for all operators other than TransformWithState.
+   */
   private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2
 
   /**
@@ -309,16 +309,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
 
   def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap
 
-  /**
-   * Function to add the ValueState schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   *
-   * @param stateName - name of the state variable
-   * @param valEncoder - SQL encoder for state variable
-   * @tparam T - type of state variable
-   * @return - instance of ValueState of type T that can be used to store state persistently
-   */
   override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
     verifyStateVarOperations("get_value_state", PRE_INIT)
     val colFamilySchema = columnFamilySchemaUtils.
@@ -327,17 +317,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     null.asInstanceOf[ValueState[T]]
   }
 
-  /**
-   * Function to add the ValueStateWithTTL schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   *
-   * @param stateName - name of the state variable
-   * @param valEncoder - SQL encoder for state variable
-   * @param ttlConfig - the ttl configuration (time to live duration etc.)
-   * @tparam T - type of state variable
-   * @return - instance of ValueState of type T that can be used to store state persistently
-   */
   override def getValueState[T](
       stateName: String,
       valEncoder: Encoder[T],
@@ -349,16 +328,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     null.asInstanceOf[ValueState[T]]
   }
 
-  /**
-   * Function to add the ListState schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   *
-   * @param stateName - name of the state variable
-   * @param valEncoder - SQL encoder for state variable
-   * @tparam T - type of state variable
-   * @return - instance of ListState of type T that can be used to store state persistently
-   */
   override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
     verifyStateVarOperations("get_list_state", PRE_INIT)
     val colFamilySchema = columnFamilySchemaUtils.
@@ -367,17 +336,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     null.asInstanceOf[ListState[T]]
   }
 
-  /**
-   * Function to add the ListStateWithTTL schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   *
-   * @param stateName - name of the state variable
-   * @param valEncoder - SQL encoder for state variable
-   * @param ttlConfig - the ttl configuration (time to live duration etc.)
-   * @tparam T - type of state variable
-   * @return - instance of ListState of type T that can be used to store state persistently
-   */
   override def getListState[T](
       stateName: String,
       valEncoder: Encoder[T],
@@ -389,17 +347,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     null.asInstanceOf[ListState[T]]
   }
 
-  /**
-   * Function to add the MapState schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   * @param stateName - name of the state variable
-   * @param userKeyEnc - spark sql encoder for the map key
-   * @param valEncoder - spark sql encoder for the map value
-   * @tparam K - type of key for map state variable
-   * @tparam V - type of value for map state variable
-   * @return - instance of MapState of type [K,V] that can be used to store state persistently
-   */
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
@@ -411,18 +358,6 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     null.asInstanceOf[MapState[K, V]]
   }
 
-  /**
-   * Function to add the MapStateWithTTL schema to the list of column family schemas.
-   * The user must ensure to call this function only within the `init()` method of the
-   * StatefulProcessor.
-   * @param stateName - name of the state variable
-   * @param userKeyEnc - spark sql encoder for the map key
-   * @param valEncoder - SQL encoder for state variable
-   * @param ttlConfig - the ttl configuration (time to live duration etc.)
-   * @tparam K - type of key for map state variable
-   * @tparam V - type of value for map state variable
-   * @return - instance of MapState of type [K,V] that can be used to store state persistently
-   */
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
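The driver-side overrides above all share one pattern: in the PRE_INIT phase, each get*State call only records the column family schema and returns a typed null placeholder, since no state store exists on the driver yet. A minimal sketch of the idea, using simplified stand-in types (SchemaInfo and DriverHandleSketch are illustrative, not Spark's types):

import scala.collection.mutable

// Simplified stand-ins for illustration; not the real Spark types.
final case class SchemaInfo(stateName: String, valueType: String)

class DriverHandleSketch {
  private val schemas = mutable.Map.empty[String, SchemaInfo]

  // On the driver we only need the schema. The returned state instance is
  // never used in this phase, so a typed null placeholder is sufficient.
  def getValueState[T](stateName: String, valueType: String): T = {
    schemas += stateName -> SchemaInfo(stateName, valueType)
    null.asInstanceOf[T]
  }

  def getColumnFamilySchemas: Map[String, SchemaInfo] = schemas.toMap
}

object DriverHandleDemo extends App {
  val handle = new DriverHandleSketch
  handle.getValueState[AnyRef]("countState", "long")
  println(handle.getColumnFamilySchemas) // Map(countState -> SchemaInfo(countState,long))
}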
@@ -249,8 +249,7 @@ case class StreamingSymmetricHashJoinExec(
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
-      stateSchemaVersion: Int
-  ): Array[String] = {
+      stateSchemaVersion: Int): Array[String] = {
     var result: Map[String, (StructType, StructType)] = Map.empty
     // get state schema for state stores on left side of the join
     result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide,
@@ -98,7 +98,7 @@ case class TransformWithStateExec(
    * and fetch the schemas of the state variables initialized in this processor.
    * @return a new instance of the driver processor handle
    */
-  private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = {
+  private def getDriverProcessorHandle(): DriverStatefulProcessorHandleImpl = {
     val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder)
     driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
     statefulProcessor.setHandle(driverProcessorHandle)
@@ -111,12 +111,16 @@ case class TransformWithStateExec(
    * after init is called.
    */
   private def getColFamilySchemas(): Map[String, ColumnFamilySchema] = {
-    val driverProcessorHandle = getDriverProcessorHandle
-    val columnFamilySchemas = driverProcessorHandle.getColumnFamilySchemas
+    val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas
     closeProcessorHandle()
     columnFamilySchemas
   }
 
+  /**
+   * This method is used to close the driver-side stateful processor after we
+   * have collected all the necessary schemas.
+   * This instance of the stateful processor won't be used again.
+   */
   private def closeProcessorHandle(): Unit = {
     statefulProcessor.close()
     statefulProcessor.setHandle(null)
@@ -373,12 +377,11 @@ case class TransformWithStateExec(
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
-      stateSchemaVersion: Int):
-    Array[String] = {
+      stateSchemaVersion: Int): Array[String] = {
     assert(stateSchemaVersion >= 3)
     val newColumnFamilySchemas = getColFamilySchemas()
     val schemaFile = new StateSchemaV3File(
-      hadoopConf, stateSchemaFilePath(StateStoreId.DEFAULT_STORE_NAME).toString)
+      hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString)
     // TODO: Read the schema path from the OperatorStateMetadata file
     // and validate it with the new schema
 
@@ -402,7 +405,7 @@ case class TransformWithStateExec(
     }
   }
 
-  private def stateSchemaFilePath(storeName: String): Path = {
+  private def stateSchemaDirPath(storeName: String): Path = {
     assert(storeName == StateStoreId.DEFAULT_STORE_NAME)
     def stateInfo = getStateInfo
     val stateCheckpointPath =
@@ -73,7 +73,7 @@ object ColumnFamilySchemaV1 {
         s"Expected Map but got ${colFamilyMap.getClass}")
     val keySchema = StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String])
     val valueSchema = StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String])
-    new ColumnFamilySchemaV1(
+    ColumnFamilySchemaV1(
       colFamilyMap("columnFamilyName").asInstanceOf[String],
       keySchema,
       valueSchema,
@@ -113,7 +113,6 @@ class StateSchemaCompatibilityChecker(
 object StateSchemaCompatibilityChecker extends Logging {
   val VERSION = 2
 
-
   /**
    * Function to check if new state store schema is compatible with the existing schema.
    * @param oldSchema - old state schema
@@ -33,15 +33,14 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers
  * The StateSchemaV3File is used to write the schema of multiple column families.
  * Right now, this is primarily used for the TransformWithState operator, which supports
  * multiple column families to keep the data for multiple state variables.
+ * We only expect ColumnFamilySchemaV1 to be written and read from this file.
  * @param hadoopConf Hadoop configuration that is used to read / write metadata files.
  * @param path Path to the directory that will be used for writing metadata.
  */
 class StateSchemaV3File(
     hadoopConf: Configuration,
     path: String) {
 
-  val VERSION = 3
-
   val metadataPath = new Path(path)
 
   protected val fileManager: CheckpointFileManager =
@@ -51,21 +50,21 @@ class StateSchemaV3File(
     fileManager.mkdirs(metadataPath)
   }
 
-  def deserialize(in: InputStream): List[ColumnFamilySchema] = {
+  private def deserialize(in: InputStream): List[ColumnFamilySchema] = {
     val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
 
     if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file in the offset commit log")
     }
 
     val version = lines.next().trim
-    validateVersion(version, VERSION)
+    validateVersion(version, StateSchemaV3File.VERSION)
 
     lines.map(ColumnFamilySchemaV1.fromJson).toList
   }
 
-  def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = {
-    out.write(s"v${VERSION}".getBytes(UTF_8))
+  private def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = {
+    out.write(s"v${StateSchemaV3File.VERSION}".getBytes(UTF_8))
     out.write('\n')
     out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8))
   }
@@ -85,7 +84,6 @@ class StateSchemaV3File(
   protected def write(
       batchMetadataFile: Path,
       fn: OutputStream => Unit): Unit = {
-    // Only write metadata when the batch has not yet been written
     val output = fileManager.createAtomic(batchMetadataFile, overwriteIfPossible = false)
     try {
       fn(output)
@@ -101,3 +99,7 @@ class StateSchemaV3File(
     new Path(metadataPath, batchId.toString)
   }
 }
+
+object StateSchemaV3File {
+  val VERSION = 3
+}
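As serialize and deserialize above show, the on-disk layout of a StateSchemaV3File is a version header line (v3) followed by one JSON document per column family schema, newline-separated. A self-contained sketch of that round trip, with plain strings standing in for the real ColumnFamilySchema JSON:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
import scala.io.{Source => IOSource}

object SchemaV3FormatSketch extends App {
  val VERSION = 3

  // Write: a "v3" header line, then newline-separated schema JSON,
  // mirroring serialize() above.
  def write(schemaJsons: List[String]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    out.write(s"v$VERSION".getBytes(UTF_8))
    out.write('\n')
    out.write(schemaJsons.mkString("\n").getBytes(UTF_8))
    out.toByteArray
  }

  // Read: validate the version line, then collect the remaining JSON lines.
  def read(bytes: Array[Byte]): List[String] = {
    val lines = IOSource.fromInputStream(new ByteArrayInputStream(bytes), UTF_8.name()).getLines()
    require(lines.hasNext && lines.next().trim == s"v$VERSION", "bad or missing version header")
    lines.toList
  }

  val jsons = List("""{"columnFamilyName":"countState"}""", """{"columnFamilyName":"mapState"}""")
  assert(read(write(jsons)) == jsons) // round trip preserves the schemas
}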
@@ -296,7 +296,7 @@ object KeyStateEncoderSpec {
         NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)
       case "RangeKeyScanStateEncoderSpec" =>
         val orderingOrdinals = m("orderingOrdinals").
-          asInstanceOf[List[_]].map(_.asInstanceOf[Int])
+          asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt)
         RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals)
       case "PrefixKeyScanStateEncoderSpec" =>
         val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt]
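The asInstanceOf[BigInt].toInt change fixes a real deserialization pitfall: json4s with default settings parses untyped JSON integers as BigInt, so after extracting to Map[String, Any] a cast straight to Int fails at runtime with a ClassCastException. A small demonstration, assuming json4s is on the classpath:

import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods

object BigIntGotcha extends App {
  implicit val formats: DefaultFormats.type = DefaultFormats

  val m = JsonMethods.parse("""{"orderingOrdinals": [0, 2]}""").extract[Map[String, Any]]
  val raw = m("orderingOrdinals").asInstanceOf[List[_]]

  // raw.map(_.asInstanceOf[Int])  // would throw: the elements are BigInt, not Int
  val ordinals = raw.map(_.asInstanceOf[BigInt].toInt) // convert explicitly
  println(ordinals) // List(0, 2)
}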
@@ -30,6 +30,8 @@ import scala.util.Random
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
@@ -1627,6 +1629,30 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       keyRow, keySchema, valueRow, keySchema, storeConf)
   }
 
+  test("test serialization and deserialization of NoPrefixKeyStateEncoderSpec") {
+    implicit val formats: DefaultFormats.type = DefaultFormats
+    val encoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+    val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]]
+    val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap)
+    assert(encoderSpec == deserializedEncoderSpec)
+  }
+
+  test("test serialization and deserialization of PrefixKeyScanStateEncoderSpec") {
+    implicit val formats: DefaultFormats.type = DefaultFormats
+    val encoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1)
+    val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]]
+    val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap)
+    assert(encoderSpec == deserializedEncoderSpec)
+  }
+
+  test("test serialization and deserialization of RangeKeyScanStateEncoderSpec") {
+    implicit val formats: DefaultFormats.type = DefaultFormats
+    val encoderSpec = RangeKeyScanStateEncoderSpec(keySchema, Seq(1))
+    val jsonMap = JsonMethods.parse(encoderSpec.json).extract[Map[String, Any]]
+    val deserializedEncoderSpec = KeyStateEncoderSpec.fromJson(keySchema, jsonMap)
+    assert(encoderSpec == deserializedEncoderSpec)
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass
 