@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

object UnsafeRowUtils {

/**
* Use the following rules to check the integrity of the UnsafeRow:
* - schema.fields.length == row.numFields should always be true
* - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be
* true if the expectedSchema contains at least one field.
* - For variable-length fields: if null bit says it's null then don't do anything, else extract
* offset and size:
* 1) 0 <= size < row.getSizeInBytes should always be true. We can be even more precise than
* this, where the upper bound of size can only be as big as the variable length part of
* the row.
* 2) offset should be >= fixed sized part of the row.
* 3) offset + size should be within the row bounds.
* - For fixed-length fields that are narrower than 8 bytes (boolean/byte/short/int/float), if
* null bit says it's null then don't do anything, else:
* check if the unused bits in the field are all zeros. The UnsafeRowWriter's write() methods
* make this guarantee.
* - Check the total length of the row.
*/
def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = {
if (expectedSchema.fields.length != row.numFields) {
return false
}
val bitSetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields)
val rowSizeInBytes = row.getSizeInBytes
if (expectedSchema.fields.length > 0 && bitSetWidthInBytes >= rowSizeInBytes) {
return false
}
var varLenFieldsSizeInBytes = 0
expectedSchema.fields.zipWithIndex.foreach {
case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
val offsetAndSize = row.getLong(index)
val offset = (offsetAndSize >> 32).toInt
val size = offsetAndSize.toInt
if (size < 0 ||
offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) {
return false
}
varLenFieldsSizeInBytes += size
case (field, index) if UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
field.dataType match {
case BooleanType =>
if ((row.getLong(index) >> 1) != 0L) return false
case ByteType =>
if ((row.getLong(index) >> 8) != 0L) return false
case ShortType =>
if ((row.getLong(index) >> 16) != 0L) return false
case IntegerType =>
if ((row.getLong(index) >> 32) != 0L) return false
case FloatType =>
if ((row.getLong(index) >> 32) != 0L) return false
case _ =>
}
case (_, index) if row.isNullAt(index) =>
if (row.getLong(index) != 0L) return false
case _ =>
}
if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
return false
}
true
}
}
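
As a rough worked illustration of the bounds these rules enforce (illustrative numbers only, not part of the patch), consider a hypothetical two-field row holding an Int and the 5-byte string "spark":

```scala
// Hypothetical layout arithmetic for an UnsafeRow with schema (Int, String), value "spark".
val numFields = 2
val bitSetWidthInBytes = ((numFields + 63) / 64) * 8               // 8: one 64-bit word of null bits
val fixedRegionInBytes = 8 * numFields                             // 16: one 8-byte slot per field
val stringSize = "spark".getBytes("UTF-8").length                  // 5, stored padded to 8 bytes
val rowSizeInBytes = bitSetWidthInBytes + fixedRegionInBytes + 8   // 32 bytes in total

// The string's fixed-length slot packs (offset << 32) | size, with offset from the row base.
val offset = bitSetWidthInBytes + fixedRegionInBytes               // 24: data starts after fixed parts
assert(stringSize >= 0 && stringSize < rowSizeInBytes)             // size bound
assert(offset >= bitSetWidthInBytes + 8 * numFields)               // offset is past the fixed-size part
assert(offset + stringSize <= rowSizeInBytes)                      // stays within the row bounds
```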
@@ -1237,6 +1237,16 @@ object SQLConf {
.intConf
.createWithDefault(10)

val STATE_STORE_FORMAT_VALIDATION_ENABLED =
buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
.internal()
.doc("When true, check if the UnsafeRow from the state store is valid or not when running " +
Member:
Change UnsafeRow to checkpoint? Most end users do not know what UnsafeRow is.

Member Author:
Sure, will submit a follow-up PR today.

"streaming queries. This can happen if the state store format has been changed. Note, " +
"the feature is only effective in the build-in HDFS state store provider now.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)
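
For reference, the flag is internal but reads like any other SQL conf; a minimal sketch (assuming a SparkSession named `spark`) of turning it off, e.g. while deliberately re-reading a legacy checkpoint:

```scala
// Sketch only: `spark` is an assumed SparkSession; the key matches the conf defined above.
spark.conf.set("spark.sql.streaming.stateStore.formatValidation.enabled", "false")
```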

val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
.internal()
@@ -1543,18 +1553,18 @@ object SQLConf {

val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS =
buildConf("spark.sql.streaming.checkpointFileManagerClass")
.internal()
.doc("The class used to write checkpoint files atomically. This class must be a subclass " +
"of the interface CheckpointFileManager.")
.version("2.4.0")
.internal()
.stringConf

val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED =
buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled")
.internal()
.doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " +
"to SPARK-26824.")
.version("3.0.0")
.internal()
.booleanConf
.createWithDefault(true)

@@ -2746,6 +2756,8 @@ class SQLConf extends Serializable with Logging {

def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)

def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)

def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

val testKeys: Seq[String] = Seq("key1", "key2")
val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)")

val testOutputSchema: StructType = StructType(
testKeys.map(createIntegerField) ++ testValues.map(createIntegerField))

val testRow: UnsafeRow = {
val unsafeRowProjection = UnsafeProjection.create(testOutputSchema)
val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema))
(testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
row
}

private def createIntegerField(name: String): StructField = {
StructField(name, IntegerType, nullable = false)
}

test("UnsafeRow format invalidation") {
// Pass the checking
UnsafeRowUtils.validateStructuralIntegrity(testRow, testOutputSchema)
// Fail for fields number not match
assert(!UnsafeRowUtils.validateStructuralIntegrity(
testRow, StructType(testKeys.map(createIntegerField))))
// Fail for invalid schema
val invalidSchema = StructType(testKeys.map(createIntegerField) ++
Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
StructField("value2", IntegerType, false)))
assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema))
}
}
@@ -259,6 +259,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _
@volatile private var numberOfVersionsToRetainInMemory: Int = _
// TODO: The validation should be moved to a higher level so that it works for all state store
// implementations
@volatile private var isValidated = false
Contributor @cloud-fan (Jun 17, 2020):
Can we add a TODO that this validation should be moved to a higher level so that it works for all state store implementations?

Member Author:
Thanks, added the TODO in fd74ff9.


private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
private lazy val baseDir = stateStoreId.storeCheckpointLocation()
@@ -457,6 +460,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
// Prior to Spark 2.3, we mistakenly appended 4 bytes to the value row in
// `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
if (!isValidated) {
StateStoreProvider.validateStateRowFormat(
keyRow, keySchema, valueRow, valueSchema, storeConf)
isValidated = true
}
map.put(keyRow, valueRow)
}
}
@@ -551,6 +559,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
// Prior to Spark 2.3, we mistakenly appended 4 bytes to the value row in
// `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
if (!isValidated) {
StateStoreProvider.validateStateRowFormat(
keyRow, keySchema, valueRow, valueSchema, storeConf)
isValidated = true
}
map.put(keyRow, valueRow)
}
}
@@ -27,9 +27,10 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.{SparkContext, SparkEnv, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -143,6 +144,16 @@ case class StateStoreCustomSumMetric(name: String, desc: String) extends StateSt
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric

/**
* An exception thrown when an invalid UnsafeRow is detected in the state store.
*/
class InvalidUnsafeRowException
extends RuntimeException("The streaming query failed because the state format is invalid. " +
"The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " +
"incompatible with the current one; 2. Broken checkpoint files; 3. The query was changed " +
"between restarts. For the first case, you can try to restart the application without the " +
Contributor:
"For the first case" — I think it's for all the cases?

Member Author:
The resolution is only for the first case; the rest of the listed cases should be considered user problems.

"checkpoint or use the legacy Spark version to process the streaming state.", null)

/**
* Trait representing a provider that provides [[StateStore]] instances representing
* versions of state data.
@@ -230,6 +241,26 @@ object StateStoreProvider {
provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider
}

/**
* Use the expected schema to check whether the UnsafeRow is valid.
*/
def validateStateRowFormat(
keyRow: UnsafeRow,
keySchema: StructType,
valueRow: UnsafeRow,
valueSchema: StructType,
conf: StateStoreConf): Unit = {
if (conf.formatValidationEnabled) {
if (!UnsafeRowUtils.validateStructuralIntegrity(keyRow, keySchema)) {
throw new InvalidUnsafeRowException
}
if (conf.formatValidationCheckValue &&
!UnsafeRowUtils.validateStructuralIntegrity(valueRow, valueSchema)) {
throw new InvalidUnsafeRowException
}
}
}
}
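
For illustration, a sketch (hypothetical provider, with `keySchema`, `valueSchema`, and `storeConf` assumed to be set up in `init`, mirroring the HDFS-backed provider elsewhere in this diff) of how another StateStoreProvider implementation could reuse the helper until the TODO about hoisting validation to a higher level is addressed:

```scala
// Sketch: validate only the first (key, value) pair loaded by a hypothetical provider.
abstract class SketchStateStoreProvider extends StateStoreProvider {
  // Assumed to be populated in init(); names mirror HDFSBackedStateStoreProvider.
  protected def keySchema: StructType
  protected def valueSchema: StructType
  protected def storeConf: StateStoreConf

  @volatile private var isValidated = false

  protected def validateOnce(keyRow: UnsafeRow, valueRow: UnsafeRow): Unit = {
    if (!isValidated) {
      StateStoreProvider.validateStateRowFormat(
        keyRow, keySchema, valueRow, valueSchema, storeConf)  // throws InvalidUnsafeRowException
      isValidated = true
    }
  }
}
```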

/**
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.sql.internal.SQLConf

/** A class that contains configuration parameters for [[StateStore]]s. */
class StateStoreConf(@transient private val sqlConf: SQLConf)
class StateStoreConf(
@transient private val sqlConf: SQLConf,
extraOptions: Map[String, String] = Map.empty)
extends Serializable {

def this() = this(new SQLConf)
Expand All @@ -43,14 +45,25 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
*/
val providerClass: String = sqlConf.stateStoreProviderClass

/** Whether to validate the underlying state row format or not. */
val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled

/** Whether to validate the value row format when format validation is enabled. */
val formatValidationCheckValue: Boolean =
extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"

/**
* Additional configurations related to state store. This will capture all configs in
* SQLConf that start with `spark.sql.streaming.stateStore.` */
* SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific
* operator.
*/
val confs: Map[String, String] =
sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) ++ extraOptions
}

object StateStoreConf {
val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue"

val empty = new StateStoreConf()

def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
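
A minimal sketch of the resulting behaviour (hypothetical values; `new SQLConf` carries the defaults, including the new validation flag which defaults to true), mirroring how StreamingDeduplicateExec passes extraOptions later in this diff:

```scala
// Per-operator conf that keeps key-row validation but skips the value-row check via extraOptions.
val storeConf = new StateStoreConf(
  new SQLConf,
  Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false"))
assert(storeConf.formatValidationEnabled)       // from the SQL conf default (true)
assert(!storeConf.formatValidationCheckValue)   // value check disabled for this operator
```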
@@ -45,10 +45,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
valueSchema: StructType,
indexOrdinal: Option[Int],
sessionState: SessionState,
@transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
@transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)
extends RDD[U](dataRDD) {

private val storeConf = new StateStoreConf(sessionState.conf)
private val storeConf = new StateStoreConf(sessionState.conf, extraOptions)

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
private val hadoopConfBroadcast = dataRDD.context.broadcast(
@@ -55,7 +55,8 @@ package object state {
valueSchema: StructType,
indexOrdinal: Option[Int],
sessionState: SessionState,
storeCoordinator: Option[StateStoreCoordinatorRef])(
storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)(
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {

val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
@@ -78,7 +79,8 @@ package object state {
valueSchema,
indexOrdinal,
sessionState,
storeCoordinator)
storeCoordinator,
extraOptions)
}
}
}
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
@@ -460,7 +461,10 @@ case class StreamingDeduplicateExec(
child.output.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
Some(sqlContext.streams.stateStoreCoordinator),
// We won't check the value row in the state store, since the value StreamingDeduplicateExec.EMPTY_ROW
// is unrelated to the output schema.
Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val numOutputRows = longMetric("numOutputRows")
val numTotalStateRows = longMetric("numTotalStateRows")