Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Base change and a draft test suite
  • Loading branch information
jingz-db committed Jul 8, 2024
commit 4849f20db379e79b6126e23d6b9c7024730d25e0
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class StatefulProcessorHandleImpl(
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends StatefulProcessorHandleImplBase(timeMode) with Logging {
extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging {
import StatefulProcessorHandleState._

/**
Expand Down Expand Up @@ -297,8 +297,8 @@ class StatefulProcessorHandleImpl(
* actually done. We need this class because we can only collect the schemas after
* the StatefulProcessor is initialized.
*/
class DriverStatefulProcessorHandleImpl(timeMode: TimeMode)
extends StatefulProcessorHandleImplBase(timeMode) {
class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any])
extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) {

private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
*/
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.NoTime
import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.{INITIALIZED, PRE_INIT, StatefulProcessorHandleState, TIMER_PROCESSED}
import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimeMode}

abstract class StatefulProcessorHandleImplBase(timeMode: TimeMode)
abstract class StatefulProcessorHandleImplBase(
timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any])
extends StatefulProcessorHandle {

protected var currState: StatefulProcessorHandleState = PRE_INIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ case class TransformWithStateExec(
* @return a new instance of the driver processor handle
*/
private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = {
val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode)
val driverProcessorHandle = new DriverStatefulProcessorHandleImpl(timeMode, keyEncoder)
driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
statefulProcessor.setHandle(driverProcessorHandle)
statefulProcessor.init(outputMode, timeMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ package org.apache.spark.sql.streaming
import java.io.File
import java.util.UUID

import org.json4s.JsonAST.JString

import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, Encoders}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
Expand Down Expand Up @@ -899,3 +902,20 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
)
}
}

/**
 * Validates the state key/value encoder schemas used by transformWithState.
 * Confirms that a [[StateTypesEncoder]] can be constructed for a primitive key
 * encoder, and that the key/value expression-encoder schemas have the expected
 * single-column shape.
 */
class TransformWithStateSchemaSuite extends StateStoreMetricsTest {

  test("transformWithState - key and value encoder schemas have expected shape") {
    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
      classOf[RocksDBStateStoreProvider].getName,
      SQLConf.SHUFFLE_PARTITIONS.key ->
        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
      // Smoke check: constructing the encoder for an Int key / String value
      // state column family must not throw.
      StateTypesEncoder(keySerializer = encoderFor(Encoders.scalaInt).createSerializer(),
        valEncoder = Encoders.STRING, stateName = "someState", hasTtl = false)

      val keySchema = encoderFor(Encoders.scalaInt).schema
      val valueSchema = Encoders.STRING.schema

      // A primitive encoder flattens to a single column; assert on the schema
      // instead of printing it (the draft used println for manual inspection).
      assert(keySchema.fields.length === 1,
        s"expected single-column key schema, got: ${keySchema.json}")
      assert(keySchema.json.contains("integer"),
        s"expected integer key column, got: ${keySchema.json}")
      assert(valueSchema.fields.length === 1,
        s"expected single-column value schema, got: ${valueSchema.json}")
      assert(valueSchema.json.contains("string"),
        s"expected string value column, got: ${valueSchema.json}")
    }
  }
}