adding OperatorStateMetadataLog
ericm-db committed Jun 6, 2024
commit d9d9d995573070c58c0221d634918f3cc3e9e96a
@@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -55,7 +55,8 @@ case class StateMetadataTableEntry(
numPartitions,
minBatchId,
maxBatchId,
numColsPrefixKey))
numColsPrefixKey
))
}
}

@@ -193,8 +194,11 @@ class StateMetadataPartitionReader(
val opIds = fileManager
.list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
opIds.map { opId =>
new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
}
val dirLocation = new Path(stateDir, opId.toString)
val metadataFilePath = OperatorStateMetadata.metadataFilePath(dirLocation)
val log = new OperatorStateMetadataLog(SparkSession.active, metadataFilePath.toString)
log.getLatest()
}.filter(_.isDefined).map(_.get._2)
}

private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
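As a companion to the new read path above, here is a minimal sketch of how the partition reader now resolves an operator's metadata through OperatorStateMetadataLog. The /tmp/ckpt/state path and operator id 0 are placeholders, and it assumes an active SparkSession:

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadata

// Resolve the metadata log for operator 0 under a (placeholder) state checkpoint root.
val stateDir = new Path("/tmp/ckpt/state")
val metadataFilePath = OperatorStateMetadata.metadataFilePath(new Path(stateDir, "0"))
val log = new OperatorStateMetadataLog(SparkSession.active, metadataFilePath.toString)
// getLatest() returns Option[(batchId, metadata)]; the reader keeps only the metadata
// and skips operators that have not written an entry yet.
val latest: Option[OperatorStateMetadata] = log.getLatest().map(_._2)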
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataWriter}
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV1
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -187,17 +187,6 @@ class IncrementalExecution(
}
}

object WriteStatefulOperatorMetadataRule extends SparkPlanPartialRule {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case stateStoreWriter: StateStoreWriter if isFirstBatch =>
val metadata = stateStoreWriter.operatorStateMetadata()
val metadataWriter = new OperatorStateMetadataWriter(new Path(
checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
metadataWriter.write(metadata)
stateStoreWriter
}
}

object StateOpIdRule extends SparkPlanPartialRule {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
@@ -473,7 +462,6 @@ class IncrementalExecution(
}
// The rule doesn't change the plan but cause the side effect that metadata is written
// in the checkpoint directory of stateful operator.
planWithStateOpId transform WriteStatefulOperatorMetadataRule.rule
simulateWatermarkPropagation(planWithStateOpId)
planWithStateOpId transform WatermarkPropagationRule.rule
}
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
@@ -88,6 +88,22 @@ class MicroBatchExecution(

@volatile protected[sql] var triggerExecutor: TriggerExecutor = _

private lazy val operatorStateMetadatas: Map[Long, OperatorStateMetadataLog] = {
populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan)
}

private def populateOperatorStateMetadatas(plan: SparkPlan):
Map[Long, OperatorStateMetadataLog] = {
plan.flatMap {
case s: StateStoreWriter => s.stateInfo.map { info =>
val metadataPath = s.metadataFilePath()
info.operatorId -> new OperatorStateMetadataLog(sparkSession,
metadataPath.toString)
}
case _ => Seq.empty
}.toMap
}

protected def getTrigger(): TriggerExecutor = {
assert(sources.nonEmpty, "sources should have been retrieved from the plan!")
trigger match {
@@ -902,6 +918,15 @@ class MicroBatchExecution(
if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
execCtx.executionPlan.executedPlan.collect {
case s: StateStoreWriter =>
val metadata = s.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val metadataFile = operatorStateMetadatas(id)
if (!metadataFile.add(execCtx.batchId, metadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
}
committedOffsets ++= execCtx.endOffsets
}
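To show what the commit-time write above buys, a small sketch of replaying the per-batch metadata history for one operator. The path and session are placeholders; get(startId, endId) is the range read inherited from HDFSMetadataLog:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog

val spark = SparkSession.active
// One entry is appended per operator per committed batch, keyed by batch id.
val log = new OperatorStateMetadataLog(spark, "/tmp/ckpt/state/0/_metadata/metadata")
log.get(Some(0L), None).foreach { case (batchId, metadata) =>
  println(s"batch $batchId wrote operator state metadata v${metadata.version}")
}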
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets._

import org.apache.hadoop.fs.FSDataOutputStream

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2}


class OperatorStateMetadataLog(sparkSession: SparkSession, path: String)
extends HDFSMetadataLog[OperatorStateMetadata](sparkSession, path) {
override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = {
val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream]
fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8))
metadata.version match {
case 1 =>
OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata)
case 2 =>
OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata)
}
}

override protected def deserialize(in: InputStream): OperatorStateMetadata = {
// called inside a try-finally where the underlying stream is closed in the caller
// create buffered reader from input stream
val bufferedReader = new BufferedReader(new InputStreamReader(in, UTF_8))
// read first line for version number, in the format "v{version}"
val version = bufferedReader.readLine()
version match {
case "v1" => OperatorStateMetadataV1.deserialize(bufferedReader)
case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader)
}
}
}
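A hedged usage sketch for the class above: each entry is a single batch file whose first line is the version tag, followed by the JSON body. The path, operator name, and store values are illustrative, and the (storeName, numColsPrefixKey, numPartitions) shape of StateStoreMetadataV1 is assumed from the existing V1 metadata:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.OperatorStateMetadataLog
import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadataV1, StateStoreMetadataV1}

val spark = SparkSession.active
val log = new OperatorStateMetadataLog(spark, "/tmp/ckpt/state/0/_metadata/metadata")

// Build a V1 payload; serialize() above writes "v1\n" followed by its JSON form.
val metadata = OperatorStateMetadataV1(
  OperatorInfoV1(operatorId = 0, operatorName = "stateStoreSave"),
  Array(StateStoreMetadataV1("default", numColsPrefixKey = 0, numPartitions = 200)))

log.add(0, metadata)          // entry for batch 0
val latest = log.getLatest()  // Some((0, metadata)) once the write has landed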
@@ -25,10 +25,13 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.json4s.{Formats, NoTypeHints}
import org.json4s.JsonAST.JValue
import org.json4s.jackson.Serialization

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
import org.apache.spark.util.AccumulatorV2

/**
* Metadata for a state store instance.
@@ -54,6 +57,15 @@ case class OperatorInfoV1(operatorId: Long, operatorName: String) extends Operat

trait OperatorStateMetadata {
def version: Int

def operatorInfo: OperatorInfo

def stateStoreInfo: Array[StateStoreMetadataV1]
}

object OperatorStateMetadata {
def metadataFilePath(stateCheckpointPath: Path): Path =
new Path(new Path(stateCheckpointPath, "_metadata"), "metadata")
}

case class OperatorStateMetadataV1(
@@ -62,6 +74,56 @@ case class OperatorStateMetadataV1(
override def version: Int = 1
}

/**
* Accumulator to store arbitrary Operator properties.
* This accumulator is used to store the properties of an operator that are not
* available on the driver at the time of planning, and will only be known from
* the executor side.
*/
class OperatorProperties(initValue: Map[String, JValue] = Map.empty)
extends AccumulatorV2[Map[String, JValue], Map[String, JValue]] {

private var _value: Map[String, JValue] = initValue

override def isZero: Boolean = _value.isEmpty

override def copy(): AccumulatorV2[Map[String, JValue], Map[String, JValue]] = {
val newAcc = new OperatorProperties
newAcc._value = _value
newAcc
}

override def reset(): Unit = _value = Map.empty[String, JValue]

override def add(v: Map[String, JValue]): Unit = _value ++= v

override def merge(other: AccumulatorV2[Map[String, JValue], Map[String, JValue]]): Unit = {
_value ++= other.value
}

override def value: Map[String, JValue] = _value
}

object OperatorProperties {
def create(
sc: SparkContext,
name: String,
initValue: Map[String, JValue] = Map.empty): OperatorProperties = {
val acc = new OperatorProperties(initValue)
acc.register(sc, name = Some(name))
acc
}
}

// operatorProperties is an arbitrary JSON formatted string that contains
// any properties that we would want to store for a particular operator.
case class OperatorStateMetadataV2(
operatorInfo: OperatorInfoV1,
stateStoreInfo: Array[StateStoreMetadataV1],
operatorPropertiesJson: String) extends OperatorStateMetadata {
override def version: Int = 2
}

object OperatorStateMetadataV1 {

private implicit val formats: Formats = Serialization.formats(NoTypeHints)
@@ -70,9 +132,6 @@ object OperatorStateMetadataV1 {
private implicit val manifest = Manifest
.classType[OperatorStateMetadataV1](implicitly[ClassTag[OperatorStateMetadataV1]].runtimeClass)

def metadataFilePath(stateCheckpointPath: Path): Path =
new Path(new Path(stateCheckpointPath, "_metadata"), "metadata")

def deserialize(in: BufferedReader): OperatorStateMetadata = {
Serialization.read[OperatorStateMetadataV1](in)
}
@@ -84,13 +143,31 @@ object OperatorStateMetadataV1 {
}
}

object OperatorStateMetadataV2 {
private implicit val formats: Formats = Serialization.formats(NoTypeHints)

@scala.annotation.nowarn
private implicit val manifest = Manifest
.classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass)

def deserialize(in: BufferedReader): OperatorStateMetadata = {
Serialization.read[OperatorStateMetadataV2](in)
}

def serialize(
out: FSDataOutputStream,
operatorStateMetadata: OperatorStateMetadata): Unit = {
Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out)
}
}

/**
* Write OperatorStateMetadata into the state checkpoint directory.
*/
class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configuration)
extends Logging {

private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath)
private val metadataFilePath = OperatorStateMetadata.metadataFilePath(stateCheckpointPath)

private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf)

@@ -101,7 +178,12 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu
val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false)
try {
outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8))
OperatorStateMetadataV1.serialize(outputStream, operatorMetadata)
operatorMetadata.version match {
case 1 =>
OperatorStateMetadataV1.serialize(outputStream, operatorMetadata)
case 2 =>
OperatorStateMetadataV2.serialize(outputStream, operatorMetadata)
}
outputStream.close()
} catch {
case e: Throwable =>
@@ -117,7 +199,7 @@ class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configu
*/
class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) {

private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath)
private val metadataFilePath = OperatorStateMetadata.metadataFilePath(stateCheckpointPath)

private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf)

@@ -127,9 +209,12 @@ class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configu
new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
try {
val versionStr = inputReader.readLine()
val version = MetadataVersionUtil.validateVersion(versionStr, 1)
assert(version == 1)
OperatorStateMetadataV1.deserialize(inputReader)
val version = MetadataVersionUtil.validateVersion(versionStr, 2)
assert(version == 1 || version == 2)
version match {
case 1 => OperatorStateMetadataV1.deserialize(inputReader)
case 2 => OperatorStateMetadataV2.deserialize(inputReader)
}
} finally {
inputStream.close()
}
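To make the intent of OperatorProperties and OperatorStateMetadataV2 concrete, a sketch of a driver-registered accumulator collecting executor-side properties that are then folded into the V2 JSON payload. The accumulator name, property key, and store values are assumptions for illustration, not part of this change:

import org.json4s.JsonAST.{JObject, JString}
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorProperties, OperatorStateMetadataV2, StateStoreMetadataV1}

val spark = SparkSession.active
// Registered on the driver; tasks call add() and Spark merges the maps back here.
val props = OperatorProperties.create(spark.sparkContext, "operatorProperties-0")
props.add(Map("timeMode" -> JString("ProcessingTime")))  // normally added executor-side

// Once the batch finishes, render the merged map as the operatorPropertiesJson field.
val propertiesJson = compact(render(JObject(props.value.toList)))
val metadataV2 = OperatorStateMetadataV2(
  OperatorInfoV1(0, "transformWithStateExec"),
  Array(StateStoreMetadataV1("default", 0, 200)),
  propertiesJson)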
@@ -23,6 +23,8 @@ import java.util.concurrent.TimeUnit._
import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
@@ -70,6 +72,12 @@ trait StatefulOperator extends SparkPlan {
throw new IllegalStateException("State location not present for execution")
}
}

def metadataFilePath(): Path = {
val stateCheckpointPath =
new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString)
new Path(new Path(stateCheckpointPath, "_metadata"), "metadata")
}
}

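Finally, a worked path example for metadataFilePath() above, assuming getStateInfo.checkpointLocation resolves to the query's state root (here /tmp/ckpt/state) and operatorId is 0; OperatorStateMetadataLog, being an HDFSMetadataLog, then keeps one file per batch under that directory:

// metadataFilePath()  =>  /tmp/ckpt/state/0/_metadata/metadata
// per-batch entries   =>  /tmp/ckpt/state/0/_metadata/metadata/0,
//                         /tmp/ckpt/state/0/_metadata/metadata/1, ...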