Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix suites & add TTL suites
  • Loading branch information
jingz-db committed Jul 9, 2024
commit 00741ff6a088198790d624d9a17b7f9c1385c79e
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}

Expand Down Expand Up @@ -292,14 +293,14 @@ object KeyStateEncoderSpec {
// match on type
m("keyStateEncoderType").asInstanceOf[String] match {
case "NoPrefixKeyStateEncoderSpec" =>
NoPrefixKeyStateEncoderSpec(keySchema)
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)
case "RangeKeyScanStateEncoderSpec" =>
val orderingOrdinals = m("orderingOrdinals").
asInstanceOf[List[_]].map(_.asInstanceOf[Int])
RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals)
case "PrefixKeyScanStateEncoderSpec" =>
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt]
PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey.toInt)
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey.toInt)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
new StructType().add("key",
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType().add("value", LongType)),
new StructType().add("value", LongType, false)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
)
Expand All @@ -949,7 +949,7 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("id", LongType)
.add("id", LongType, false)
.add("name", StringType)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
Expand All @@ -961,12 +961,11 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
.add("userKey", new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("id", IntegerType)
.add("id", IntegerType, false)
.add("name", StringType)),
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
Option(new StructType().add("value", StringType))
)
println("print out schema0: " + schema0)

val inputData = MemoryStream[String]
val result = inputData.toDS()
Expand All @@ -993,9 +992,9 @@ class TransformWithStateSchemaSuite extends StateStoreMetricsTest {
q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt)

assert(colFamilySeq.length == 3)
assert(colFamilySeq.toSet == Set(
assert(colFamilySeq.map(_.toString).toSet == Set(
schema0, schema1, schema2
))
).map(_.toString))
},
StopStream
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ package org.apache.spark.sql.streaming

import java.time.Duration

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaV3File}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types._

object TTLInputProcessFunction {
def processRow(
Expand Down Expand Up @@ -111,15 +115,15 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig)
}
}

case class MultipleValueStatesTTLProcessor(
class MultipleValueStatesTTLProcessor(
ttlKey: String,
noTtlKey: String,
ttlConfig: TTLConfig)
extends StatefulProcessor[String, InputEvent, OutputEvent]
with Logging {

@transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
@transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _
@transient var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
@transient var _valueStateWithoutTTL: ValueStateImpl[Int] = _

override def init(
outputMode: OutputMode,
Expand Down Expand Up @@ -160,6 +164,28 @@ case class MultipleValueStatesTTLProcessor(
}
}

class TTLProcessorWithCompositeTypes(
ttlKey: String,
noTtlKey: String,
ttlConfig: TTLConfig)
extends MultipleValueStatesTTLProcessor(
ttlKey: String, noTtlKey: String, ttlConfig: TTLConfig) {
@transient private var _listStateWithTTL: ListStateImplWithTTL[Int] = _
@transient private var _mapStateWithTTL: MapStateImplWithTTL[Int, String] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
super.init(outputMode, timeMode)
_listStateWithTTL = getHandle
.getListState("listState", Encoders.scalaInt, ttlConfig)
.asInstanceOf[ListStateImplWithTTL[Int]]
_mapStateWithTTL = getHandle
.getMapState("mapState", Encoders.scalaInt, Encoders.STRING, ttlConfig)
.asInstanceOf[MapStateImplWithTTL[Int, String]]
}
}

class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {

import testImplicits._
Expand All @@ -181,7 +207,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val result = inputStream.toDS()
.groupByKey(x => x.key)
.transformWithState(
MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
new MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
TimeMode.ProcessingTime(),
OutputMode.Append())

Expand Down Expand Up @@ -225,4 +251,108 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
)
}
}

test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
withTempDir { checkpointDir =>
val metadataPathPostfix = "state/0/default/_metadata"
val stateSchemaPath = new Path(checkpointDir.toString,
s"$metadataPathPostfix/schema/0")
val hadoopConf = spark.sessionState.newHadoopConf()
val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)

val schema0 = ColumnFamilySchemaV1(
"valueState",
new StructType().add("key",
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType().add("value", LongType, false)
.add("ttlExpirationMs", LongType)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
)
val schema1 = ColumnFamilySchemaV1(
"listState",
new StructType().add("key",
new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("value", IntegerType, false)
.add("ttlExpirationMs", LongType)),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
None
)
val schema2 = ColumnFamilySchemaV1(
"mapState",
new StructType()
.add("key", new StructType().add("value", StringType))
.add("userKey", new StructType().add("value", StringType)),
new StructType().add("value",
new StructType()
.add("value", IntegerType, false)
.add("ttlExpirationMs", LongType)),
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
Option(new StructType().add("value", StringType))
)

val ttlKey = "k1"
val noTtlKey = "k2"
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
val inputStream = MemoryStream[InputEvent]
val result = inputStream.toDS()
.groupByKey(x => x.key)
.transformWithState(
new TTLProcessorWithCompositeTypes(ttlKey, noTtlKey, ttlConfig),
TimeMode.ProcessingTime(),
OutputMode.Append())

val clock = new StreamManualClock
testStream(result)(
StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
checkpointLocation = checkpointDir.getCanonicalPath),
AddData(inputStream, InputEvent(ttlKey, "put", 1)),
AddData(inputStream, InputEvent(noTtlKey, "put", 2)),
// advance clock to trigger processing
AdvanceManualClock(1 * 1000),
CheckNewAnswer(),
Execute { q =>
println("last progress:" + q.lastProgress)
val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
val ssv3 = new StateSchemaV3File(hadoopConf, new Path(checkpointDir.toString,
metadataPathPostfix).toString)
val colFamilySeq = ssv3.deserialize(fm.open(schemaFilePath))

assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics
.get("numValueStateWithTTLVars").toInt)
assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics
.get("numListStateWithTTLVars").toInt)
assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
q.lastProgress.stateOperators.head.customMetrics
.get("numMapStateWithTTLVars").toInt)

// TODO when there are two state var with the same name,
// only one schema file is preserved
assert(colFamilySeq.length == 3)
/*
assert(colFamilySeq.map(_.toString).toSet == Set(
schema0, schema1, schema2
).map(_.toString)) */

assert(colFamilySeq(1).toString == schema1.toString)
assert(colFamilySeq(2).toString == schema2.toString)
// The remaining schema file is the one without ttl
// assert(colFamilySeq.head.toString == schema0.toString)
},
StopStream
)
}
}
}
}