trying to plumb schema through planning rule
ericm-db committed Jun 13, 2024
commit 3554af22d3a1eb0b96702c26a91c691b3c9e21f2
@@ -187,6 +187,23 @@ class IncrementalExecution(
}
}

object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging {
logError(s"### PopulateSchemaV3Rule, batchId = $currentBatchId")
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case tws: TransformWithStateExec =>
val stateSchemaV3File = new StateSchemaV3File(
hadoopConf, tws.stateSchemaFilePath().toString)
logError(s"### trying to get schema from file: ${tws.stateSchemaFilePath()}")
stateSchemaV3File.getLatest() match {
case Some((_, schemaJValue)) =>
logError("### PASSING SCHEMA TO OPERATOR")
logError(s"### schemaJValue: $schemaJValue")
tws.copy(columnFamilyJValue = Some(schemaJValue))
case None => tws
}
}
}

object StateOpIdRule extends SparkPlanPartialRule {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
@@ -454,16 +471,18 @@
}

override def apply(plan: SparkPlan): SparkPlan = {
logError(s"### applying rules to plan")
val planWithStateOpId = plan transform composedRule
val planWithSchema = planWithStateOpId transform PopulateSchemaV3Rule.rule
// Need to check before writing the metadata because we need to detect added operators.
// Only check when the stream is restarting and this is the first batch.
if (isFirstBatch && currentBatchId != 0) {
checkOperatorValidWithMetadata(planWithStateOpId)
checkOperatorValidWithMetadata(planWithSchema)
}
// The rule doesn't change the plan, but causes the side effect that metadata is written
// in the checkpoint directory of the stateful operator.
simulateWatermarkPropagation(planWithStateOpId)
planWithStateOpId transform WatermarkPropagationRule.rule
simulateWatermarkPropagation(planWithSchema)
planWithSchema transform WatermarkPropagationRule.rule
}
}

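The new PopulateSchemaV3Rule runs after the state-op-id rules: it opens the operator's StateSchemaV3File and, if a schema entry has already been committed, copies the latest JSON value onto TransformWithStateExec so the executor side can recreate the column families from the persisted schema. The sketch below shows the "latest entry wins" lookup in isolation; ToyOperator and the in-memory log are illustrative stand-ins, not Spark's SparkPlan or HDFSMetadataLog API.

    import org.json4s.JValue
    import org.json4s.jackson.JsonMethods.parse

    final case class ToyOperator(name: String, columnFamilyJValue: Option[JValue] = None)

    // Mirror of the rule's behaviour: take the entry with the highest batch id, if any,
    // and return a copy of the operator carrying that schema.
    def attachLatestSchema(op: ToyOperator, schemaLog: Map[Long, JValue]): ToyOperator =
      schemaLog.toSeq.sortBy(_._1).lastOption match {
        case Some((_, schema)) => op.copy(columnFamilyJValue = Some(schema))
        case None => op // first run: nothing written yet, keep the operator unchanged
      }

    // Usage: batch 0 wrote a schema, so planning the next batch picks it up.
    val log = Map(0L -> parse("""{"columnFamilySchemas": [{"columnFamilyName": "listState"}]}"""))
    attachLatestSchema(ToyOperator("transformWithState"), log)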
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.ListState

/**
@@ -44,8 +44,9 @@ class ListStateImpl[S](

private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)

store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
store.createColFamilyIfAbsent(columnFamilySchema)

/** Whether state exists or not. */
override def exists(): Boolean = {
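The store-side change in the state implementations follows one pattern: instead of passing the key schema, value schema, encoder spec and multi-value flag as loose positional arguments, they are bundled into a ColumnFamilySchemaV1 and the whole object is handed to createColFamilyIfAbsent. One detail worth checking: the removed call passed useMultipleValuesPerKey = true for list state, while the new schema is built with false as its last argument; if that boolean is the same flag, the flip may be unintended. Below is a hedged sketch of the bundling pattern with illustrative names (ColumnFamilyLayout is not Spark's class).

    import org.apache.spark.sql.types.{BinaryType, StructType}

    // Illustrative value object: one immutable description of a column family that can be
    // passed to the store and also serialized into a schema file.
    final case class ColumnFamilyLayout(
        columnFamilyName: String,
        keySchema: StructType,
        valueSchema: StructType,
        multipleValuesPerKey: Boolean)

    val keyRowSchema = new StructType().add("key", BinaryType)
    val valueRowSchema = new StructType().add("value", BinaryType)

    // Before: createColFamilyIfAbsent(name, keySchema, valueSchema, spec, useMultipleValuesPerKey = true)
    // After:  createColFamilyIfAbsent(layout) with everything carried by one value.
    val listStateLayout =
      ColumnFamilyLayout("listState", keyRowSchema, valueRowSchema, multipleValuesPerKey = true)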
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.util.NextIterator

@@ -52,11 +52,13 @@ class ListStateImplWithTTL[S](
private lazy val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true)
initialize()

private def initialize(): Unit = {
store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
store.createColFamilyIfAbsent(columnFamilySchema)
}

/** Whether state exists or not. */
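In the TTL variant the schema is assigned to a val immediately above the initialize() call that consumes it. Scala runs class-body statements in order during construction, so a val read by a method invoked from the body must be defined before that invocation or it will still be null. A tiny self-contained illustration (toy class, not the Spark code):

    class Example {
      val schema: String = "key: binary, value: binary" // initialized first
      init()                                            // runs during construction, after the val above

      private def init(): Unit =
        require(schema != null, "schema must be assigned before init() runs")
    }

    new Example() // fine; moving init() above the val would make `schema` null inside init()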
@@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.streaming.MapState
import org.apache.spark.sql.types.{BinaryType, StructType}

class MapStateImpl[K, V](
store: StateStore,
@@ -30,18 +30,15 @@ class MapStateImpl[K, V](
userKeyEnc: Encoder[K],
valEncoder: Encoder[V]) extends MapState[K, V] with Logging {

// Pack grouping key and user key together as a prefixed composite key
private val schemaForCompositeKeyRow: StructType =
new StructType()
.add("key", BinaryType)
.add("userKey", BinaryType)
private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
private val keySerializer = keyExprEnc.createSerializer()
private val stateTypesEncoder = new CompositeKeyStateEncoder(
keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName)
keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName)

store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow,
PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)

store.createColFamilyIfAbsent(columnFamilySchema)

/** Whether state exists or not. */
override def exists(): Boolean = {
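MapStateImpl switches from its locally declared binary composite-key/value schemas to the shared COMPOSITE_KEY_ROW_SCHEMA and VALUE_ROW_SCHEMA constants, and describes the column family with a prefix-scan encoder over the first key column so that all user keys under one grouping key can be read in a single scan. The toy in-memory version below illustrates why the composite (groupingKey, userKey) layout enables that; it is not the state store API (the real store encodes the pair as UnsafeRows, not delimited strings).

    import scala.collection.immutable.TreeMap

    // Composite key encoded as "groupingKey#userKey" so entries with the same grouping key
    // sort next to each other in an ordered store.
    def compositeKey(groupingKey: String, userKey: String): String = s"$groupingKey#$userKey"

    val store = TreeMap(
      compositeKey("user-1", "clicks") -> 10,
      compositeKey("user-1", "views")  -> 25,
      compositeKey("user-2", "clicks") -> 3)

    // "Prefix scan": return every user key stored under one grouping key.
    def scanGroupingKey(groupingKey: String): Map[String, Int] =
      store.iterator
        .filter { case (k, _) => k.startsWith(groupingKey + "#") }
        .map { case (k, v) => k.split('#')(1) -> v }
        .toMap

    scanGroupingKey("user-1") // Map(clicks -> 10, views -> 25)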
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{MapState, TTLConfig}
import org.apache.spark.util.NextIterator

@@ -55,11 +55,13 @@ class MapStateImplWithTTL[K, V](
private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
initialize()

private def initialize(): Unit = {
store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
store.createColFamilyIfAbsent(columnFamilySchema)
}

/** Whether state exists or not. */
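Both TTL-enabled implementations turn the configured duration into an absolute expiration timestamp once per batch (StateTTL.calculateExpirationTimeForDuration in the surrounding context). Assuming that helper is the straightforward additive deadline its name suggests, the arithmetic looks like the sketch below; the function names here are illustrative, not Spark's.

    import java.time.Duration

    // Assumed shape: expiration = batch timestamp + ttl duration, both in epoch millis.
    def expirationTimeMs(ttl: Duration, batchTimestampMs: Long): Long =
      batchTimestampMs + ttl.toMillis

    def isExpired(expirationMs: Long, nowMs: Long): Boolean = nowMs >= expirationMs

    val exp = expirationTimeMs(Duration.ofMinutes(10), batchTimestampMs = 1718000000000L)
    isExpired(exp, nowMs = 1718000000000L + Duration.ofMinutes(11).toMillis) // true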
@@ -906,10 +906,22 @@ class MicroBatchExecution(
val shouldWriteMetadatas = execCtx.previousContext match {
case Some(prevCtx)
if prevCtx.executionPlan.runId == execCtx.executionPlan.runId =>
false
false
case _ => true
}

if (shouldWriteMetadatas) {
execCtx.executionPlan.executedPlan.collect {
case tws: TransformWithStateExec =>
val schema = tws.getColumnFamilyJValue()
val metadata = tws.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val schemaFile = stateSchemaLogs(id)
logError(s"Writing schema for operator $id at path ${schemaFile.metadataPath}")
if (!schemaFile.add(execCtx.batchId, schema)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
execCtx.executionPlan.executedPlan.collect {
case s: StateStoreWriter =>
val metadata = s.operatorStateMetadata()
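When the batch is the first one planned by this run (shouldWriteMetadatas), the execution now also walks the executed plan, finds each TransformWithStateExec, looks up that operator's schema log by operator id, and appends the collected column-family JSON under the current batch id; add() returning false means another writer already committed that batch, which is surfaced as a concurrent-update error. The stand-in below captures just that add-or-fail contract of the metadata log; it is a toy, not HDFSMetadataLog.

    import scala.collection.mutable

    final class ToySchemaLog[T] {
      private val entries = mutable.TreeMap.empty[Long, T]
      // Returns false when the batch id is already present, i.e. someone else wrote it first.
      def add(batchId: Long, payload: T): Boolean =
        if (entries.contains(batchId)) false
        else { entries.put(batchId, payload); true }
      def getLatest(): Option[(Long, T)] = entries.lastOption
    }

    val schemaLog = new ToySchemaLog[String]
    require(schemaLog.add(0L, """{"columnFamilySchemas": []}""")) // first writer for batch 0 wins
    if (!schemaLog.add(0L, """{"columnFamilySchemas": []}"""))    // a second writer must fail loudly
      throw new IllegalStateException("concurrent update to the schema log for batch 0")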
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import java.io.{InputStream, OutputStream, StringReader}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream}
import org.json4s.JValue
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

class StateSchemaV3File(
hadoopConf: Configuration,
path: String,
metadataCacheEnabled: Boolean = false)
extends HDFSMetadataLog[JValue](hadoopConf, path, metadataCacheEnabled) {

final val MAX_UTF_CHUNK_SIZE = 65535
def this(sparkSession: SparkSession, path: String) = {
this(
sparkSession.sessionState.newHadoopConf(),
path,
metadataCacheEnabled = sparkSession.sessionState.conf.getConf(
SQLConf.STREAMING_METADATA_CACHE_ENABLED)
)
}

override protected def serialize(schema: JValue, out: OutputStream): Unit = {
val json = compact(render(schema))
val buf = new Array[Char](MAX_UTF_CHUNK_SIZE)

val outputStream = out.asInstanceOf[FSDataOutputStream]
// DataOutputStream.writeUTF can't write a string at once
// if the size exceeds 65535 (2^16 - 1) bytes.
// Each metadata consists of multiple chunks in schema version 3.
try {
val numMetadataChunks = (json.length - 1) / MAX_UTF_CHUNK_SIZE + 1
val metadataStringReader = new StringReader(json)
outputStream.writeInt(numMetadataChunks)
(0 until numMetadataChunks).foreach { _ =>
val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE)
outputStream.writeUTF(new String(buf, 0, numRead))
}
outputStream.close()
} catch {
case e: Throwable =>
throw e
}
}

override protected def deserialize(in: InputStream): JValue = {
val buf = new StringBuilder
val inputStream = in.asInstanceOf[FSDataInputStream]
val numKeyChunks = inputStream.readInt()
(0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF()))
val json = buf.toString()
JsonMethods.parse(json)
}
}
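The new log stores one JSON document per batch, but DataOutputStream.writeUTF (which FSDataOutputStream inherits) rejects strings whose encoded form exceeds 65535 bytes, so serialize() writes the chunk count followed by 65535-character slices of the JSON, and deserialize() reads the count back and reassembles the string. Note the chunk size is measured in characters, which only stays under the byte limit while the JSON is essentially ASCII. The same round-trip over plain java.io streams, as a standalone sketch:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, StringReader}

    val MaxUtfChunkSize = 65535

    def writeChunked(json: String, out: DataOutputStream): Unit = {
      val numChunks = (json.length - 1) / MaxUtfChunkSize + 1
      val reader = new StringReader(json)
      val buf = new Array[Char](MaxUtfChunkSize)
      out.writeInt(numChunks)                       // chunk count goes first
      (0 until numChunks).foreach { _ =>
        val numRead = reader.read(buf, 0, MaxUtfChunkSize)
        out.writeUTF(new String(buf, 0, numRead))   // each chunk stays under the writeUTF limit
      }
    }

    def readChunked(in: DataInputStream): String = {
      val numChunks = in.readInt()
      val sb = new StringBuilder
      (0 until numChunks).foreach(_ => sb.append(in.readUTF()))
      sb.toString
    }

    val bytes = new ByteArrayOutputStream()
    writeChunked("""{"columnFamilySchemas": [{"columnFamilyName": "listState"}]}""",
      new DataOutputStream(bytes))
    readChunked(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray)))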
@@ -98,6 +98,9 @@ class StatefulProcessorHandleImpl(
private[sql] val stateVariables: util.List[StateVariableInfo] =
new util.ArrayList[StateVariableInfo]()

private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] =
new util.ArrayList[ColumnFamilySchema]()

private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"

private def buildQueryInfo(): QueryInfo = {
@@ -139,6 +142,8 @@
new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
val colFamilySchema = resultState.columnFamilySchema
columnFamilySchemas.add(colFamilySchema)
null
}
}
@@ -158,6 +163,8 @@
valueStateWithTTL
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
val colFamilySchema = resultState.columnFamilySchema
columnFamilySchemas.add(colFamilySchema)
null
}
}
@@ -296,6 +303,8 @@
listStateWithTTL
case None =>
stateVariables.add(new StateVariableInfo(stateName, ListState, true))
val colFamilySchema = resultState.columnFamilySchema
columnFamilySchemas.add(colFamilySchema)
null
}
}
@@ -311,6 +320,8 @@
new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
case None =>
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
val colFamilySchema = resultState.columnFamilySchema
columnFamilySchemas.add(colFamilySchema)
null
}
}
@@ -331,6 +342,8 @@
mapStateWithTTL
case None =>
stateVariables.add(new StateVariableInfo(stateName, MapState, true))
val colFamilySchema = resultState.columnFamilySchema
columnFamilySchemas.add(colFamilySchema)
null
}
}
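In the pre-initialization (driver-side) phase, each getValueState/getListState/getMapState call still returns null, but now also records the state variable's column-family schema into columnFamilySchemas alongside the existing StateVariableInfo, so the operator can serialize the full set before any task runs. A toy sketch of that record-only handle (illustrative names, not the real handle API):

    import scala.collection.mutable

    final case class ColumnFamilyDecl(name: String, multipleValuesPerKey: Boolean)

    // Driver-side recording handle: nothing usable is returned to the processor yet,
    // but the declared column families are captured for the operator's metadata.
    final class RecordingHandle {
      val declared = mutable.ListBuffer.empty[ColumnFamilyDecl]
      def getListState(name: String): Null = {
        declared += ColumnFamilyDecl(name, multipleValuesPerKey = true)
        null
      }
    }

    val handle = new RecordingHandle
    handle.getListState("countsPerKey")
    handle.declared // ListBuffer(ColumnFamilyDecl(countsPerKey,true)), later rendered to JSON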
@@ -244,6 +244,10 @@ abstract class StreamExecution(
populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan)
}

lazy val stateSchemaLogs: Map[Long, StateSchemaV3File] = {
populateStateSchemaFiles(getLatestExecutionContext().executionPlan.executedPlan)
}

private def populateOperatorStateMetadatas(
plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = {
plan.flatMap {
@@ -256,6 +260,18 @@
}.toMap
}

private def populateStateSchemaFiles(
plan: SparkPlan): Map[Long, StateSchemaV3File] = {
plan.flatMap {
case s: StateStoreWriter => s.stateInfo.map { info =>
val schemaFilePath = s.stateSchemaFilePath()
info.operatorId -> new StateSchemaV3File(sparkSession,
schemaFilePath.toString)
}
case _ => Seq.empty
}.toMap
}

/** Whether all fields of the query have been initialized */
private def isInitialized: Boolean = state.get != INITIALIZING

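stateSchemaLogs mirrors the existing operatorStateMetadatas map: it flat-maps over the executed plan, keeps only stateful writers that carry state info, and keys one StateSchemaV3File per operator id. The sketch below shows that collection shape with a toy node type; the checkpoint path literal is hypothetical.

    final case class ToyNode(operatorId: Option[Long], schemaPath: Option[String])

    // Keep only nodes that have both an operator id and a schema location, keyed by id.
    def schemaLogsByOperator(nodes: Seq[ToyNode]): Map[Long, String] =
      nodes.flatMap { node =>
        for {
          id   <- node.operatorId
          path <- node.schemaPath
        } yield id -> s"StateSchemaV3File($path)" // stand-in for constructing the real log
      }.toMap

    schemaLogsByOperator(Seq(
      ToyNode(Some(0L), Some("checkpoint/state/0/_stateSchema")), // hypothetical layout
      ToyNode(None, None)))                                       // non-stateful node: skipped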
@@ -81,7 +81,8 @@ case class TransformWithStateExec(
initialStateGroupingAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: SparkPlan)
initialState: SparkPlan,
columnFamilyJValue: Option[JValue] = None)
extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {

val operatorProperties: util.Map[String, JValue] =
@@ -91,6 +92,7 @@

override def shortName: String = "transformWithStateExec"

columnFamilySchemas()

/** Metadata of this stateful operator and its states stores. */
override def operatorStateMetadata(): OperatorStateMetadata = {
@@ -107,6 +109,20 @@
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
}

def getColumnFamilyJValue(): JValue = {
val columnFamilySchemas = operatorProperties.get("columnFamilySchemas")
columnFamilySchemas
}

def columnFamilySchemas(): List[ColumnFamilySchema] = {
val columnFamilySchemas = ColumnFamilySchemaV1.fromJValue(columnFamilyJValue)
columnFamilySchemas.foreach {
case c1: ColumnFamilySchemaV1 => logError(s"### colFamilyName:" +
s"${c1.columnFamilyName}")
}
columnFamilySchemas
}

override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
if (timeMode == ProcessingTime) {
// TODO: check if we can return true only if actual timers are registered, or there is
@@ -382,6 +398,9 @@ case class TransformWithStateExec(
statefulProcessor.init(outputMode, timeMode)
operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables.
asScala.map(_.jsonValue).toList))
operatorProperties.put("columnFamilySchemas", JArray(driverProcessorHandle.
columnFamilySchemas.asScala.map(_.jsonValue).toList))

statefulProcessor.setHandle(null)
driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
