Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
writing schema and metadata in planning rule
  • Loading branch information
ericm-db committed Jun 14, 2024
commit 52a579e0842fc9cb5e33e9511eb106a8a1fa8058
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.hadoop.fs.Path
import org.json4s.JsonAST.JValue

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR}
Expand All @@ -37,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
Expand Down Expand Up @@ -187,15 +188,48 @@ class IncrementalExecution(
}
}

/**
 * Persists the state schema and the operator state metadata for the current batch.
 *
 * Entries in the operator metadata log with a batch id greater than `currentBatchId - 1`
 * are purged first, then the schema and the metadata are both added under `currentBatchId`.
 *
 * NOTE(review): only the operator metadata log is purged here; the schema log is not.
 * If `add` returns false for a batch id that already exists, a schema entry left behind
 * by an earlier attempt at this batch would make this method throw - TODO confirm.
 *
 * @param stateSchemaV3File log that receives `stateSchema`
 * @param operatorStateMetadataLog log that is purged and then receives `operatorStateMetadata`
 * @param stateSchema schema to record for the current batch
 * @param operatorStateMetadata operator metadata to record for the current batch
 * @throws Throwable the error built by `QueryExecutionErrors.concurrentStreamLogUpdate`
 *                   when either `add` call reports failure
 */
def writeSchemaAndMetadataFiles(
    stateSchemaV3File: StateSchemaV3File,
    operatorStateMetadataLog: OperatorStateMetadataLog,
    stateSchema: JValue,
    operatorStateMetadata: OperatorStateMetadata): Unit = {
  val batchId = currentBatchId

  // Drop metadata entries past the previous batch before writing this batch's entry.
  operatorStateMetadataLog.purgeAfter(batchId - 1)

  val schemaWritten = stateSchemaV3File.add(batchId, stateSchema)
  if (!schemaWritten) {
    throw QueryExecutionErrors.concurrentStreamLogUpdate(batchId)
  }

  val metadataWritten = operatorStateMetadataLog.add(batchId, operatorStateMetadata)
  if (!metadataWritten) {
    throw QueryExecutionErrors.concurrentStreamLogUpdate(batchId)
  }
}

object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case tws: TransformWithStateExec if isFirstBatch && currentBatchId != 0 =>
case tws: TransformWithStateExec if isFirstBatch =>
val stateSchemaV3File = new StateSchemaV3File(
hadoopConf, tws.stateSchemaFilePath().toString)
val operatorStateMetadataLog = new OperatorStateMetadataLog(
hadoopConf,
tws.metadataFilePath().toString
)
stateSchemaV3File.getLatest() match {
case Some((_, schemaJValue)) =>
tws.copy(columnFamilyJValue = Some(schemaJValue))
case None => tws
case Some((_, oldSchema)) =>
val newSchema = tws.getSchema()
tws.compareSchemas(oldSchema, newSchema)
writeSchemaAndMetadataFiles(
stateSchemaV3File = stateSchemaV3File,
operatorStateMetadataLog = operatorStateMetadataLog,
stateSchema = newSchema,
operatorStateMetadata = tws.operatorStateMetadata()
)
tws.copy(columnFamilyJValue = Some(oldSchema))
case None =>
writeSchemaAndMetadataFiles(
stateSchemaV3File = stateSchemaV3File,
operatorStateMetadataLog = operatorStateMetadataLog,
stateSchema = tws.getSchema(),
operatorStateMetadata = tws.operatorStateMetadata()
)
tws
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,43 +899,6 @@ class MicroBatchExecution(
*/
protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = {
watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
val shouldWriteMetadatas = execCtx.previousContext match {
case Some(prevCtx)
if prevCtx.executionPlan.runId == execCtx.executionPlan.runId =>
false
case _ => true
}

if (shouldWriteMetadatas) {
// clean up any batchIds that are greater than or equal to
// the current batchId
execCtx.executionPlan.executedPlan.collect {
case tws: TransformWithStateExec =>
val metadata = tws.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val metadataFile = operatorStateMetadataLogs(id)
metadataFile.purgeAfter(execCtx.batchId - 1)
}
execCtx.executionPlan.executedPlan.collect {
case tws: TransformWithStateExec =>
val metadata = tws.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val schemaFile = stateSchemaLogs(id)
val schema = tws.getColumnFamilyJValue()
if (!schemaFile.add(execCtx.batchId, schema)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
execCtx.executionPlan.executedPlan.collect {
case s: StateStoreWriter =>
val metadata = s.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val metadataFile = operatorStateMetadataLogs(id)
if (!metadataFile.add(execCtx.batchId, metadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
}
execCtx.reportTimeTaken("commitOffsets") {
if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,6 @@ class StatefulProcessorHandleImpl(

def getHandleState: StatefulProcessorHandleState = currState

/**
 * Rejects a state variable whose column family was already registered with a different schema.
 *
 * Only `ColumnFamilySchemaV1` values are checked; any other schema kind is accepted as-is.
 * A V1 schema whose name is not present in `existingColumnFamilies` is also accepted.
 *
 * @param columnFamilySchema schema of the state variable about to be created
 * @throws RuntimeException if a column family with the same name exists and its `json`
 *                          differs from that of `columnFamilySchema`
 */
private def verifyStateVariableCreation(columnFamilySchema: ColumnFamilySchema): Unit = {
  columnFamilySchema match {
    case v1Schema: ColumnFamilySchemaV1 =>
      // A single lookup replaces the contains-then-apply pair.
      existingColumnFamilies.get(v1Schema.columnFamilyName).foreach { registered =>
        if (registered.json != columnFamilySchema.json) {
          throw new RuntimeException(
            s"State variable with name ${v1Schema.columnFamilyName} already exists " +
              s"with different schema.")
        }
      }
    case _ => // nothing to verify for other schema kinds
  }
}

override def getValueState[T](
stateName: String,
valEncoder: Encoder[T]): ValueState[T] = {
Expand All @@ -157,7 +144,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand All @@ -179,7 +165,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand Down Expand Up @@ -287,7 +272,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, ListState, false))
val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand Down Expand Up @@ -325,7 +309,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, ListState, true))
val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand All @@ -343,7 +326,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
val colFamilySchema = MapStateImpl.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand All @@ -366,7 +348,6 @@ class StatefulProcessorHandleImpl(
case None =>
stateVariables.add(new StateVariableInfo(stateName, MapState, true))
val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName)
verifyStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,30 +92,51 @@ case class TransformWithStateExec(

override def shortName: String = "transformWithStateExec"

columnFamilySchemas()

/** Metadata of this stateful operator and its state stores. */
override def operatorStateMetadata(): OperatorStateMetadata = {
val info = getStateInfo
val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
val stateStoreInfo =
Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions))

val driverProcessorHandle = getDriverProcessorHandle
val stateVariables = JArray(driverProcessorHandle.stateVariables.
asScala.map(_.jsonValue).toList)

closeProcessorHandle(driverProcessorHandle)
val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~
("outputMode" -> JString(outputMode.toString)) ~
("stateVariables" -> operatorProperties.get("stateVariables"))
("stateVariables" -> stateVariables)

val json = compact(render(operatorPropertiesJson))
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
}

def getColumnFamilyJValue(): JValue = {
val columnFamilySchemas = operatorProperties.get("columnFamilySchemas")
/**
 * Builds the JSON description of the column family schemas registered by the
 * stateful processor.
 *
 * A driver-side processor handle is obtained through `getDriverProcessorHandle`, the
 * `jsonValue` of every entry in its `columnFamilySchemas` is collected into a `JArray`,
 * and the handle is released through `closeProcessorHandle`.
 *
 * @return a `JArray` with one element per column family schema held by the handle
 */
def getSchema(): JValue = {
  val driverProcessorHandle = getDriverProcessorHandle
  try {
    JArray(driverProcessorHandle.columnFamilySchemas.asScala.map(_.jsonValue).toList)
  } finally {
    // Release the handle even when rendering the schemas throws, so the stateful
    // processor is not left with a driver handle attached.
    closeProcessorHandle(driverProcessorHandle)
  }
}

def columnFamilySchemas(): List[ColumnFamilySchema] = {
ColumnFamilySchemaV1.fromJValue(columnFamilyJValue)
/**
 * Checks that no column family present in both schemas has changed.
 *
 * Both arguments are parsed with `ColumnFamilySchemaV1.fromJValue`. Column families are
 * paired by `columnFamilyName`; families that appear on only one side are ignored.
 *
 * @param oldSchema previously recorded schema
 * @param newSchema schema to validate against `oldSchema`
 * @throws RuntimeException if a column family exists in both schemas and the `json`
 *                          of the two versions differs
 */
def compareSchemas(oldSchema: JValue, newSchema: JValue): Unit = {
  val previousFamilies = ColumnFamilySchemaV1.fromJValue(oldSchema)

  // Index the new column families by name for direct lookup.
  val currentByName = ColumnFamilySchemaV1.fromJValue(newSchema).map {
    case current: ColumnFamilySchemaV1 => (current.columnFamilyName, current)
  }.toMap

  previousFamilies.foreach {
    case previous: ColumnFamilySchemaV1 =>
      currentByName
        .get(previous.columnFamilyName)
        .filter(candidate => previous.json != candidate.json)
        .foreach { changed =>
          throw new RuntimeException(
            s"State variable with name ${changed.columnFamilyName}" +
              s" already exists with different schema.")
        }
  }
}

override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
Expand Down Expand Up @@ -379,30 +400,26 @@ case class TransformWithStateExec(
)
}

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

validateTimeMode()

val existingColumnFamilies = columnFamilySchemas().map {
case c1: ColumnFamilySchemaV1 =>
c1.columnFamilyName -> c1
}.toMap

protected def getDriverProcessorHandle: StatefulProcessorHandleImpl = {
val driverProcessorHandle = new StatefulProcessorHandleImpl(
None, getStateInfo.queryRunId, keyEncoder, timeMode,
isStreaming, batchTimestampMs, metrics, existingColumnFamilies)

isStreaming, batchTimestampMs, metrics)
driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT)
statefulProcessor.setHandle(driverProcessorHandle)
statefulProcessor.init(outputMode, timeMode)
operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables.
asScala.map(_.jsonValue).toList))
operatorProperties.put("columnFamilySchemas", JArray(driverProcessorHandle.
columnFamilySchemas.asScala.map(_.jsonValue).toList))
driverProcessorHandle
}

protected def closeProcessorHandle(processorHandle: StatefulProcessorHandleImpl): Unit = {
statefulProcessor.close()
statefulProcessor.setHandle(null)
driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
}

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

validateTimeMode()

if (hasInitialState) {
val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
Expand Down