Skip to content
Closed

init #24

Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9f038aa
[SPARK-50112] Moving Avro files to sql/core so they can be used by Tr…
ericm-db Oct 25, 2024
28c3dbd
moving scala to scala dir
ericm-db Oct 25, 2024
2e33fd1
adding deprecated one
ericm-db Oct 25, 2024
b037859
init
ericm-db Oct 25, 2024
c1db91d
adding enum
ericm-db Oct 25, 2024
a30a29d
feedback and test
ericm-db Oct 25, 2024
2ebf6a8
creating utils class
ericm-db Oct 25, 2024
0559480
micheal feedback
ericm-db Oct 31, 2024
d3845a5
ValueState post-refactor
ericm-db Nov 1, 2024
35b3b0d
multivalue state encoder
ericm-db Nov 1, 2024
dcf0df7
encodeToUnsafeRow avro method
ericm-db Nov 2, 2024
dfc6b1e
using correct val
ericm-db Nov 4, 2024
5b98aa6
comments
ericm-db Nov 4, 2024
0d37ffd
calling encodeUnsafeRow
ericm-db Nov 4, 2024
9a1f825
merge into upstream/master
ericm-db Nov 5, 2024
5c8dd33
Merge remote-tracking branch 'upstream/master' into avro
ericm-db Nov 5, 2024
9b8dd5d
[SPARK-50127] Implement Avro encoding for MapState and PrefixKeyScanS…
ericm-db Nov 7, 2024
448ea76
making schema conversion lazy
ericm-db Nov 7, 2024
386fbf1
batch succeeds
ericm-db Nov 7, 2024
896e24f
actually enabling ttl
ericm-db Nov 7, 2024
15c5f71
including hidden files
ericm-db Nov 7, 2024
1f5e5f7
testWithEncodingTypes
ericm-db Nov 7, 2024
1826d5a
no longer relying on unsaferow
ericm-db Nov 8, 2024
c5ef895
everything but batch works
ericm-db Nov 8, 2024
e22e1a2
splitting it up
ericm-db Nov 8, 2024
730cae0
easy feedback to address
ericm-db Nov 9, 2024
754ce6c
batch works
ericm-db Nov 9, 2024
b6dbfdb
added test suite for non-contiguous ordinals
ericm-db Nov 11, 2024
e6f0b7a
using negative/null val marker
ericm-db Nov 11, 2024
ca660c0
removing log line
ericm-db Nov 11, 2024
41de8ae
getAvroEnc
ericm-db Nov 11, 2024
c49acd2
init
ericm-db Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feedback and test
  • Loading branch information
ericm-db committed Oct 25, 2024
commit a30a29d6c7d74131a4bf737f5d23d48914a3adae
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8

import org.apache.spark.sql.avro.AvroUtils.{ nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) {
val avroOptions = AvroOptions(Map.empty)
val keyAvroType = SchemaConverters.toAvroType(keySchema)
val keySer = new AvroSerializer(keySchema, keyAvroType, nullable = false)
val ser = new AvroSerializer(valSchema, avroType, nullable = false)
val de = new AvroDeserializer(avroType, valSchema,
val valueSerializer = new AvroSerializer(valSchema, avroType, nullable = false)
val valueDeserializer = new AvroDeserializer(avroType, valSchema,
avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
Some(AvroEncoderSpec(keySer, ser, de))
Some(AvroEncoderSpec(keySer, valueSerializer, valueDeserializer))
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,19 @@ class AvroTypesEncoder[V](
valEncoder: Encoder[V],
stateName: String,
hasTtl: Boolean,
avroSerde: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] {
avroEnc: Option[AvroEncoderSpec]) extends StateTypesEncoder[V, Array[Byte]] {

val out = new ByteArrayOutputStream
private lazy val out = new ByteArrayOutputStream

/** Variables reused for value conversions between spark sql and object */
private val keySerializer = keyEncoder.createSerializer()
private val valExpressionEnc = encoderFor(valEncoder)
private val objToRowSerializer = valExpressionEnc.createSerializer()
private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer()

// case class -> dataType
private val keySchema = keyEncoder.schema
// dataType -> avroType
private val keyAvroType = SchemaConverters.toAvroType(keySchema)

// case class -> dataType
Expand All @@ -211,9 +213,10 @@ class AvroTypesEncoder[V](
if (keyOption.isEmpty) {
throw StateStoreErrors.implicitKeyNotFound(stateName)
}
assert(avroEnc.isDefined)

val keyRow = keySerializer.apply(keyOption.get).copy() // V -> InternalRow
val avroData = avroSerde.get.keySerializer.serialize(keyRow) // InternalRow -> GenericDataRecord
val avroData = avroEnc.get.keySerializer.serialize(keyRow) // InternalRow -> GenericDataRecord

out.reset()
val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
Expand All @@ -225,9 +228,10 @@ class AvroTypesEncoder[V](
}

override def encodeValue(value: V): Array[Byte] = {
assert(avroEnc.isDefined)
val objRow: InternalRow = objToRowSerializer.apply(value).copy() // V -> InternalRow
val avroData =
avroSerde.get.valueSerializer.serialize(objRow) // InternalRow -> GenericDataRecord
avroEnc.get.valueSerializer.serialize(objRow) // InternalRow -> GenericDataRecord
out.reset()

val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
Expand All @@ -240,16 +244,18 @@ class AvroTypesEncoder[V](
}

override def decodeValue(row: Array[Byte]): V = {
assert(avroEnc.isDefined)
val reader = new GenericDatumReader[Any](valueAvroType)
val decoder = DecoderFactory.get().binaryDecoder(row, 0, row.length, null)
val genericData = reader.read(null, decoder) // bytes -> GenericDataRecord
val internalRow = avroSerde.get.valueDeserializer.deserialize(
val internalRow = avroEnc.get.valueDeserializer.deserialize(
genericData).orNull.asInstanceOf[InternalRow] // GenericDataRecord -> InternalRow
if (hasTtl) {
rowToObjDeserializer.apply(internalRow.getStruct(0, valEncoder.schema.length))
} else rowToObjDeserializer.apply(internalRow)
}

// TODO: Implement the below methods for TTL
override def encodeValue(value: V, expirationMs: Long): Array[Byte] = {
throw new UnsupportedOperationException
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,24 @@ class StatefulProcessorHandleImpl(
resultState
}

// This method is for unit-testing ListState, as the avroEnc will not be
// populated unless the handle is created through the TransformWithStateExec operator
private[sql] def getListStateWithAvro[T](
stateName: String,
valEncoder: Encoder[T],
useAvro: Boolean): ListState[T] = {
verifyStateVarOperations("get_list_state", CREATED)
val avroEnc = if (useAvro) {
new StateStoreColumnFamilySchemaUtils(true).getListStateSchema[T](
stateName, keyEncoder, valEncoder, hasTtl = false).avroEnc
} else {
None
}
val resultState = new ListStateImpl[T](
store, stateName, keyEncoder, valEncoder, avroEnc)
resultState
}

/**
* Function to create new or return existing list state variable of given type
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ class ListStateSuite extends StateVariableSuiteBase {
// overwrite useMultipleValuesPerKey in base suite to be true for list state
override def useMultipleValuesPerKey: Boolean = true

private def testMapStateWithNullUserKey()(runListOps: ListState[Long] => Unit): Unit = {
private def testMapStateWithNullUserKey(useAvro: Boolean)
(runListOps: ListState[Long] => Unit): Unit = {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())

val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong)
val listState: ListState[Long] = handle.getListStateWithAvro[Long](
"listState", Encoders.scalaLong, useAvro)

ImplicitGroupingKeyTracker.setImplicitKey("test_key")
val e = intercept[SparkIllegalArgumentException] {
Expand All @@ -57,8 +59,8 @@ class ListStateSuite extends StateVariableSuiteBase {
}

Seq("appendList", "put").foreach { listImplFunc =>
test(s"Test list operation($listImplFunc) with null") {
testMapStateWithNullUserKey() { listState =>
testWithAvroEnc(s"Test list operation($listImplFunc) with null") { useAvro =>
testMapStateWithNullUserKey(useAvro) { listState =>
listImplFunc match {
case "appendList" => listState.appendList(null)
case "put" => listState.put(null)
Expand All @@ -67,13 +69,14 @@ class ListStateSuite extends StateVariableSuiteBase {
}
}

test("List state operations for single instance") {
testWithAvroEnc("List state operations for single instance") { useAvro =>
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())

val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong)
val testState: ListState[Long] = handle.getListStateWithAvro[Long](
"testState", Encoders.scalaLong, useAvro)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")

// simple put and get test
Expand All @@ -95,14 +98,16 @@ class ListStateSuite extends StateVariableSuiteBase {
}
}

test("List state operations for multiple instance") {
testWithAvroEnc("List state operations for multiple instance") { useAvro =>
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())

val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong)
val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong)
val testState1: ListState[Long] = handle.getListStateWithAvro[Long](
"testState1", Encoders.scalaLong, useAvro)
val testState2: ListState[Long] = handle.getListStateWithAvro[Long](
"testState2", Encoders.scalaLong, useAvro)

ImplicitGroupingKeyTracker.setImplicitKey("test_key")

Expand Down Expand Up @@ -133,16 +138,18 @@ class ListStateSuite extends StateVariableSuiteBase {
}
}

test("List state operations with list, value, another list instances") {
testWithAvroEnc("List state operations with list, value, another list instances") { useAvro =>
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())

val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong)
val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong)
val valueState: ValueState[Long] = handle.getValueState[Long](
"valueState", Encoders.scalaLong)
val listState1: ListState[Long] = handle.getListStateWithAvro[Long](
"listState1", Encoders.scalaLong, useAvro)
val listState2: ListState[Long] = handle.getListStateWithAvro[Long](
"listState2", Encoders.scalaLong, useAvro)
val valueState: ValueState[Long] = handle.getValueStateWithAvro[Long](
"valueState", Encoders.scalaLong, useAvro = false)

ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// simple put and get test
Expand Down Expand Up @@ -245,7 +252,7 @@ class ListStateSuite extends StateVariableSuiteBase {
}
}

test("ListState TTL with non-primitive types") {
testWithAvroEnc("ListState TTL with non-primitive types") { useAvro =>
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val timestampMs = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ class MapStateSuite extends StateVariableSuiteBase {
val mapTestState2: MapState[String, Int] =
handle.getMapState[String, Int]("mapTestState2", Encoders.STRING, Encoders.scalaInt)
val valueTestState: ValueState[String] =
handle.getValueState[String]("valueTestState", Encoders.STRING)
handle.getValueStateWithAvro[String]("valueTestState", Encoders.STRING, false)
val listTestState: ListState[String] =
handle.getListState[String]("listTestState", Encoders.STRING)
handle.getListStateWithAvro[String]("listTestState", Encoders.STRING, false)

ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// put initial values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ trait AlsoTestWithChangelogCheckpointingEnabled
}
}

def testWithAvroEncoding(testName: String, testTags: Tag*)
(testBody: => Any): Unit = {
def testWithEncodingTypes(testName: String, testTags: Tag*)
(testBody: => Any): Unit = {
Seq("UnsafeRow", "Avro").foreach { encoding =>
super.test(testName + s" (encoding = $encoding)", testTags: _*) {
// in case tests have any code that needs to execute before every test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
Encoders.STRING, TTLConfig(Duration.ofHours(1)))

// create another state without TTL, this should not be captured in the handle
handle.getValueState("testState", Encoders.STRING)
handle.getValueStateWithAvro("testState", Encoders.STRING, useAvro = false)

assert(handle.ttlStates.size() === 1)
assert(handle.ttlStates.get(0) === valueStateWithTTL)
Expand Down Expand Up @@ -275,7 +275,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(), stringEncoder, TimeMode.None())

handle.getValueState("testValueState", Encoders.STRING)
handle.getValueStateWithAvro("testValueState", Encoders.STRING, useAvro = false)
handle.getListState("testListState", Encoders.STRING)
handle.getMapState("testMapState", Encoders.STRING, Encoders.STRING)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class TransformWithListStateSuite extends StreamTest
with AlsoTestWithChangelogCheckpointingEnabled {
import testImplicits._

testWithAvroEncoding("test appending null value in list state throw exception") {
testWithEncodingTypes("test appending null value in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -149,7 +149,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test putting null value in list state throw exception") {
testWithEncodingTypes("test putting null value in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -169,7 +169,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test putting null list in list state throw exception") {
testWithEncodingTypes("test putting null list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -189,7 +189,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test appending null list in list state throw exception") {
testWithEncodingTypes("test appending null list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -209,7 +209,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test putting empty list in list state throw exception") {
testWithEncodingTypes("test putting empty list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -229,7 +229,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test appending empty list in list state throw exception") {
testWithEncodingTypes("test appending empty list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand All @@ -249,7 +249,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test list state correctness") {
testWithEncodingTypes("test list state correctness") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand Down Expand Up @@ -307,7 +307,7 @@ class TransformWithListStateSuite extends StreamTest
}
}

testWithAvroEncoding("test ValueState And ListState in Processor") {
testWithEncodingTypes("test ValueState And ListState in Processor") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

Expand Down
Loading