init #24
Closed

Changes from 1 commit (of the 32 commits in this pull request):
9f038aa  [SPARK-50112] Moving Avro files to sql/core so they can be used by Tr…  (ericm-db, Oct 25, 2024)
28c3dbd  moving scala to scala dir  (ericm-db, Oct 25, 2024)
2e33fd1  adding deprecated one  (ericm-db, Oct 25, 2024)
b037859  init  (ericm-db, Oct 25, 2024)
c1db91d  adding enum  (ericm-db, Oct 25, 2024)
a30a29d  feedback and test  (ericm-db, Oct 25, 2024)
2ebf6a8  creating utils class  (ericm-db, Oct 25, 2024)
0559480  micheal feedback  (ericm-db, Oct 31, 2024)
d3845a5  ValueState post-refactor  (ericm-db, Nov 1, 2024)
35b3b0d  multivalue state encoder  (ericm-db, Nov 1, 2024)
dcf0df7  encodeToUnsafeRow avro method  (ericm-db, Nov 2, 2024)
dfc6b1e  using correct val  (ericm-db, Nov 4, 2024)
5b98aa6  comments  (ericm-db, Nov 4, 2024)
0d37ffd  calling encodeUnsafeRow  (ericm-db, Nov 4, 2024)
9a1f825  merge into upstream/master  (ericm-db, Nov 5, 2024)
5c8dd33  Merge remote-tracking branch 'upstream/master' into avro  (ericm-db, Nov 5, 2024)
9b8dd5d  [SPARK-50127] Implement Avro encoding for MapState and PrefixKeyScanS…  (ericm-db, Nov 7, 2024)
448ea76  making schema conversion lazy  (ericm-db, Nov 7, 2024)
386fbf1  batch succeeds  (ericm-db, Nov 7, 2024)
896e24f  actually enabling ttl  (ericm-db, Nov 7, 2024)
15c5f71  including hidden files  (ericm-db, Nov 7, 2024)
1f5e5f7  testWithEncodingTypes  (ericm-db, Nov 7, 2024)
1826d5a  no longer relying on unsaferow  (ericm-db, Nov 8, 2024)
c5ef895  everything but batch works  (ericm-db, Nov 8, 2024)
e22e1a2  splitting it up  (ericm-db, Nov 8, 2024)
730cae0  easy feedback to address  (ericm-db, Nov 9, 2024)
754ce6c  batch works  (ericm-db, Nov 9, 2024)
b6dbfdb  added test suite for non-contiguous ordinals  (ericm-db, Nov 11, 2024)
e6f0b7a  using negative/null val marker  (ericm-db, Nov 11, 2024)
ca660c0  removing log line  (ericm-db, Nov 11, 2024)
41de8ae  getAvroEnc  (ericm-db, Nov 11, 2024)
c49acd2  init  (ericm-db, Nov 5, 2024)
Commit shown in this diff:
c5ef895875cd8d677ec70f7cf7612116d06e0c8b  everything but batch works  (committed by ericm-db, Nov 8, 2024)
@@ -259,6 +259,19 @@ class IncrementalExecution(
     }
   }
 
+  object StateStoreColumnFamilySchemas extends SparkPlanPartialRule {
+    override val rule: PartialFunction[SparkPlan, SparkPlan] = {
+      case statefulOp: StatefulOperator =>
+        statefulOp match {
+          case transformWithStateExec: TransformWithStateExec =>
+            transformWithStateExec.copy(
+              columnFamilySchemas = transformWithStateExec.getColFamilySchemas()
+            )
+          case _ => statefulOp
+        }
+    }
+  }
+
   object StateOpIdRule extends SparkPlanPartialRule {
     override val rule: PartialFunction[SparkPlan, SparkPlan] = {
       case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
@@ -552,9 +565,9 @@ class IncrementalExecution(
       // The rule below doesn't change the plan but can cause the side effect that
       // metadata/schema is written in the checkpoint directory of stateful operator.
       planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule
 
-      simulateWatermarkPropagation(planWithStateOpId)
-      planWithStateOpId transform WatermarkPropagationRule.rule
+      val planWithStateSchemas = planWithStateOpId transform StateStoreColumnFamilySchemas.rule
+      simulateWatermarkPropagation(planWithStateSchemas)
+      planWithStateSchemas transform WatermarkPropagationRule.rule
     }
   }
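The new StateStoreColumnFamilySchemas rule above follows Catalyst's partial-rule pattern: `transform` walks the physical plan and applies a `PartialFunction` only to the nodes it is defined for, here copying the column family schemas onto `TransformWithStateExec` before watermark propagation runs. A minimal, self-contained sketch of that pattern, using toy plan types rather than Spark's classes:

```scala
// Toy model of Catalyst's `transform` with a partial rule. None of these
// types are Spark's; they only illustrate how a rule like
// StateStoreColumnFamilySchemas rewrites matching nodes and leaves the rest alone.
sealed trait Plan {
  def children: Seq[Plan]
  def withChildren(newChildren: Seq[Plan]): Plan
  // Rewrite children first, then apply the rule to this node if defined.
  def transform(rule: PartialFunction[Plan, Plan]): Plan = {
    val rewritten = withChildren(children.map(_.transform(rule)))
    rule.applyOrElse(rewritten, identity[Plan])
  }
}
case class Scan(table: String) extends Plan {
  def children: Seq[Plan] = Nil
  def withChildren(newChildren: Seq[Plan]): Plan = this
}
// Stand-in for TransformWithStateExec: schemas start empty and are filled
// in by the rule, mirroring the copy(...) call in the real rule.
case class StatefulOp(schemas: Map[String, String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def withChildren(newChildren: Seq[Plan]): Plan = copy(child = newChildren.head)
}

object PartialRuleDemo extends App {
  val rule: PartialFunction[Plan, Plan] = {
    case op: StatefulOp => op.copy(schemas = Map("valueState" -> "struct<value:long>"))
  }
  val plan = StatefulOp(Map.empty, Scan("events"))
  println(plan.transform(rule)) // schemas filled in; Scan child untouched
}
```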
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.ListState
 import org.apache.spark.sql.types.StructType
 
@@ -42,7 +42,7 @@ class ListStateImpl[S](
     keyExprEnc: ExpressionEncoder[Any],
     valEncoder: ExpressionEncoder[Any],
     metrics: Map[String, SQLMetric] = Map.empty,
-    avroEnc: Option[AvroEncoderSpec] = None)
+    avroEnc: Option[AvroEncoder] = None)
   extends ListStateMetricsImpl
   with ListState[S]
   with Logging {
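The `AvroEncoderSpec` to `AvroEncoder` rename seen here recurs through the rest of the diff. The definition of `AvroEncoder` itself is not shown; judging from the six-argument call site in `StateStoreColumnFamilySchemaUtils` (`AvroEncoder(keySer, keyDe, valueSerializer, valueDeserializer, suffixKeySer, suffixKeyDe)`), it presumably bundles one serializer/deserializer pair each for the key, the value, and an optional suffix key. A hypothetical sketch, with the field names assumed:

```scala
package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer}

// Presumed shape of AvroEncoder, inferred from its call site; the actual
// definition is not part of this diff, and these field names are guesses.
case class AvroEncoder(
    keySerializer: AvroSerializer,
    keyDeserializer: AvroDeserializer,
    valueSerializer: AvroSerializer,
    valueDeserializer: AvroDeserializer,
    suffixKeySerializer: Option[AvroSerializer] = None,
    suffixKeyDeserializer: Option[AvroDeserializer] = None)
```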
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{ListState, TTLConfig}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
@@ -38,7 +38,7 @@ import org.apache.spark.util.NextIterator
  * @param metrics - metrics to be updated as part of stateful processing
  * @param avroEnc - optional Avro serializer and deserializer for this state variable that
  *                  is used by the StateStore to encode state in Avro format
- * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ * @param secondaryIndexAvroEnc - optional Avro serializer and deserializer for TTL state that
  *                  is used by the StateStore to encode state in Avro format
  * @tparam S - data type of object that will be stored
  */
@@ -50,9 +50,10 @@ class ListStateImplWithTTL[S](
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
     metrics: Map[String, SQLMetric] = Map.empty,
-    avroEnc: Option[AvroEncoderSpec] = None,
-    ttlAvroEnc: Option[AvroEncoderSpec] = None)
-  extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs, ttlAvroEnc)
+    avroEnc: Option[AvroEncoder] = None,
+    secondaryIndexAvroEnc: Option[AvroEncoder] = None)
+  extends SingleKeyTTLStateImpl(
+    stateName, store, keyExprEnc, batchTimestampMs, secondaryIndexAvroEnc)
   with ListStateMetricsImpl
   with ListState[S] {
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
 import org.apache.spark.sql.streaming.MapState
 import org.apache.spark.sql.types.StructType
 
@@ -44,7 +44,7 @@ class MapStateImpl[K, V](
     userKeyEnc: ExpressionEncoder[Any],
     valEncoder: ExpressionEncoder[Any],
     metrics: Map[String, SQLMetric] = Map.empty,
-    avroEnc: Option[AvroEncoderSpec] = None) extends MapState[K, V] with Logging {
+    avroEnc: Option[AvroEncoder] = None) extends MapState[K, V] with Logging {
 
   // Pack grouping key and user key together as a prefixed composite key
   private val schemaForCompositeKeyRow: StructType = {
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.{MapState, TTLConfig}
 import org.apache.spark.util.NextIterator
 
@@ -38,25 +38,25 @@ import org.apache.spark.util.NextIterator
  * @param metrics - metrics to be updated as part of stateful processing
  * @param avroEnc - optional Avro serializer and deserializer for this state variable that
  *                  is used by the StateStore to encode state in Avro format
- * @param ttlAvroEnc - optional Avro serializer and deserializer for TTL state that
+ * @param secondaryIndexAvroEnc - optional Avro serializer and deserializer for TTL state that
  *                  is used by the StateStore to encode state in Avro format
  * @tparam K - type of key for map state variable
  * @tparam V - type of value for map state variable
  * @return - instance of MapState of type [K,V] that can be used to store state persistently
  */
 class MapStateImplWithTTL[K, V](
-    store: StateStore,
-    stateName: String,
-    keyExprEnc: ExpressionEncoder[Any],
-    userKeyEnc: ExpressionEncoder[Any],
-    valEncoder: ExpressionEncoder[Any],
-    ttlConfig: TTLConfig,
-    batchTimestampMs: Long,
-    metrics: Map[String, SQLMetric] = Map.empty,
-    avroEnc: Option[AvroEncoderSpec] = None,
-    ttlAvroEnc: Option[AvroEncoderSpec] = None)
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: ExpressionEncoder[Any],
+    valEncoder: ExpressionEncoder[Any],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long,
+    metrics: Map[String, SQLMetric] = Map.empty,
+    avroEnc: Option[AvroEncoder] = None,
+    secondaryIndexAvroEnc: Option[AvroEncoder] = None)
   extends CompositeKeyTTLStateImpl[K](stateName, store,
-    keyExprEnc, userKeyEnc, batchTimestampMs, ttlAvroEnc)
+    keyExprEnc, userKeyEnc, batchTimestampMs, secondaryIndexAvroEnc)
   with MapState[K, V] with Logging {
 
   private val stateTypesEncoder = new CompositeKeyStateEncoder(
@@ -21,15 +21,14 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema}
 import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StructField, StructType}
 
 object StateStoreColumnFamilySchemaUtils {
 
   def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils =
     new StateStoreColumnFamilySchemaUtils(initializeAvroSerde)
 
-
   /**
    * Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans
    * we want to use big-endian encoding, so we need to convert the source schema to replace these
@@ -76,6 +75,16 @@
  */
 class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Logging {
 
+  private def getAvroSerdeForSchema(schema: StructType): (AvroSerializer, AvroDeserializer) = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    val avroOptions = AvroOptions(Map.empty)
+    val serializer = new AvroSerializer(schema, avroType, nullable = false)
+    val deserializer = new AvroDeserializer(avroType, schema,
+      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+    (serializer, deserializer)
+  }
+
   /**
    * If initializeAvroSerde is true, this method will create an Avro Serializer and Deserializer
    * for a particular key and value schema.
@@ -84,30 +93,19 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
       keySchema: StructType,
       valSchema: StructType,
       suffixKeySchema: Option[StructType] = None
-  ): Option[AvroEncoderSpec] = {
+  ): Option[AvroEncoder] = {
     if (initializeAvroSerde) {
-      val avroType = SchemaConverters.toAvroType(valSchema)
-      val avroOptions = AvroOptions(Map.empty)
-      val keyAvroType = SchemaConverters.toAvroType(keySchema)
-      val keySer = new AvroSerializer(keySchema, keyAvroType, nullable = false)
-      val keyDe = new AvroDeserializer(keyAvroType, keySchema,
-        avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
-        avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-      val valueSerializer = new AvroSerializer(valSchema, avroType, nullable = false)
-      val valueDeserializer = new AvroDeserializer(avroType, valSchema,
-        avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
-        avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+      val (keySer, keyDe) =
+        getAvroSerdeForSchema(keySchema)
+      val (valueSerializer, valueDeserializer) =
+        getAvroSerdeForSchema(valSchema)
       val (suffixKeySer, suffixKeyDe) = if (suffixKeySchema.isDefined) {
-        val userKeyAvroType = SchemaConverters.toAvroType(suffixKeySchema.get)
-        val skSer = new AvroSerializer(suffixKeySchema.get, userKeyAvroType, nullable = false)
-        val skDe = new AvroDeserializer(userKeyAvroType, suffixKeySchema.get,
-          avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
-          avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-        (Some(skSer), Some(skDe))
+        val serde = getAvroSerdeForSchema(suffixKeySchema.get)
+        (Some(serde._1), Some(serde._2))
       } else {
        (None, None)
      }
-      Some(AvroEncoderSpec(
+      Some(AvroEncoder(
        keySer, keyDe, valueSerializer, valueDeserializer, suffixKeySer, suffixKeyDe))
    } else {
      None
@@ -164,6 +162,11 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
     )
   }
 
+  // This function creates the StateStoreColFamilySchema for
+  // the TTL secondary index.
+  // Because we want to encode fixed-length types as binary types
+  // if we are using Avro, we need to do some schema conversion to ensure
+  // we can use range scan
   def getTtlStateSchema(
       stateName: String,
       keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
@@ -184,6 +187,11 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
     )
   }
 
+  // This function creates the StateStoreColFamilySchema for
+  // the TTL secondary index.
+  // Because we want to encode fixed-length types as binary types
+  // if we are using Avro, we need to do some schema conversion to ensure
+  // we can use range scan
   def getTtlStateSchema(
       stateName: String,
       keyEncoder: ExpressionEncoder[Any],
@@ -221,6 +229,11 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
     ))
   }
 
+  // This function creates the StateStoreColFamilySchema for
+  // Timers' secondary index.
+  // Because we want to encode fixed-length types as binary types
+  // if we are using Avro, we need to do some schema conversion to ensure
+  // we can use range scan
   def getTimerStateSchemaForSecIndex(
       stateName: String,
       keySchema: StructType,
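The comments added in this file all point at the same constraint: Avro serializes Ints and Longs as zig-zag varints, which are not byte-wise comparable, while the state store's range scans compare raw key bytes. Writing those fields as fixed-width big-endian binary, with the sign bit flipped, makes unsigned lexicographic byte order agree with signed numeric order, which is why the TTL and timer schemas swap fixed-length types for BinaryType. A self-contained sketch of the trick (illustrative only, not this PR's code):

```scala
import java.nio.ByteBuffer
import java.util.Arrays

object BigEndianOrderDemo extends App {
  // Flip the sign bit, then write fixed-width big-endian bytes. After this,
  // unsigned byte-wise comparison of the arrays matches signed Long order.
  def encodeLong(v: Long): Array[Byte] =
    ByteBuffer.allocate(8).putLong(v ^ Long.MinValue).array()

  val values = Seq(-2L, -1L, 0L, 1L, 42L)
  val sortedByBytes = values.sortWith { (x, y) =>
    Arrays.compareUnsigned(encodeLong(x), encodeLong(y)) < 0
  }
  assert(sortedByBytes == values.sorted) // byte order == numeric order
  println(sortedByBytes) // List(-2, -1, 0, 1, 42)
}
```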
@@ -394,10 +394,10 @@ class DriverStatefulProcessorHandleImpl(
     val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
     val secIndexColFamilyName = TimerStateUtils.getSecIndexColFamilyName(timeMode.toString)
     val timerEncoder = new TimerKeyEncoder(keyExprEnc)
-    val colFamilySchema = schemaUtils.
-      getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
-    val secIndexColFamilySchema = schemaUtils.
-      getTimerStateSchemaForSecIndex(secIndexColFamilyName,
+    val colFamilySchema = schemaUtils
+      .getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
+    val secIndexColFamilySchema = schemaUtils
+      .getTimerStateSchemaForSecIndex(secIndexColFamilyName,
       timerEncoder.keySchemaForSecIndex,
       timerEncoder.schemaForValueRow)
     columnFamilySchemas.put(stateName, colFamilySchema)
@@ -458,8 +458,8 @@ class DriverStatefulProcessorHandleImpl(
     }
 
     val stateEncoder = encoderFor[T]
-    val colFamilySchema = schemaUtils.
-      getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
+    val colFamilySchema = schemaUtils
+      .getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
     checkIfDuplicateVariableDefined(stateName)
     columnFamilySchemas.put(stateName, colFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.
@@ -494,8 +494,8 @@ class DriverStatefulProcessorHandleImpl(
     }
 
 
-    val colFamilySchema = schemaUtils.
-      getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled)
+    val colFamilySchema = schemaUtils
+      .getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled)
     columnFamilySchemas.put(stateName, colFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.
       getMapState(stateName, ttlEnabled = ttlEnabled)
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils.getTtlColFamilyName
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.execution.streaming.state.{AvroEncoder, RangeKeyScanStateEncoderSpec, StateStore}
 import org.apache.spark.sql.types._
 
 object StateTTLSchema {
@@ -81,7 +81,7 @@ abstract class SingleKeyTTLStateImpl(
     store: StateStore,
     keyExprEnc: ExpressionEncoder[Any],
     ttlExpirationMs: Long,
-    avroEnc: Option[AvroEncoderSpec] = None)
+    avroEnc: Option[AvroEncoder] = None)
   extends TTLState {
 
   import org.apache.spark.sql.execution.streaming.StateTTLSchema._
@@ -202,7 +202,7 @@ abstract class CompositeKeyTTLStateImpl[K](
     keyExprEnc: ExpressionEncoder[Any],
     userKeyEncoder: ExpressionEncoder[Any],
     ttlExpirationMs: Long,
-    avroEnc: Option[AvroEncoderSpec] = None)
+    avroEnc: Option[AvroEncoder] = None)
   extends TTLState {
 
   import org.apache.spark.sql.execution.streaming.StateTTLSchema._
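For context on why `SingleKeyTTLStateImpl` and `CompositeKeyTTLStateImpl` take an encoder geared for range scans (`RangeKeyScanStateEncoderSpec` is imported above): TTL state acts as a secondary index ordered by expiration timestamp, so expired entries can be found with a bounded range scan instead of a full pass over the primary state. A toy sketch of that idea, not the StateStore API:

```scala
import scala.collection.mutable

object TtlIndexDemo extends App {
  // Primary state: key -> value.
  val primary = mutable.Map.empty[String, String]
  // Secondary index sorted by (expirationMs, key); this ordering is what the
  // range-scan key encoder must preserve at the byte level.
  val ttlIndex = mutable.TreeMap.empty[(Long, String), Unit]

  def put(key: String, value: String, expirationMs: Long): Unit = {
    primary(key) = value
    ttlIndex((expirationMs, key)) = ()
  }

  // Evict everything that expired strictly before nowMs with one range scan.
  def evictExpired(nowMs: Long): Unit = {
    val expired = ttlIndex.range((Long.MinValue, ""), (nowMs, "")).keys.toSeq
    expired.foreach { case (exp, key) =>
      ttlIndex.remove((exp, key))
      primary.remove(key)
    }
  }

  put("a", "1", expirationMs = 100L)
  put("b", "2", expirationMs = 200L)
  evictExpired(nowMs = 150L)
  println(primary) // Map(b -> 2)
}
```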
@@ -65,8 +65,8 @@ class TimerStateImpl(
     store: StateStore,
     timeMode: TimeMode,
     keyExprEnc: ExpressionEncoder[Any],
-    avroEnc: Option[AvroEncoderSpec] = None,
-    secIndexAvroEnc: Option[AvroEncoderSpec] = None) extends Logging {
+    avroEnc: Option[AvroEncoder] = None,
+    secIndexAvroEnc: Option[AvroEncoder] = None) extends Logging {
 
   private val EMPTY_ROW =
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))