diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
new file mode 100644
index 000000000000..37a34fac6636
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types._
+
+object UnsafeRowUtils {
+
+  /**
+   * Use the following rules to check the integrity of the UnsafeRow:
+   * - schema.fields.length == row.numFields should always be true
+   * - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be
+   *   true if the expectedSchema contains at least one field.
+   * - For variable-length fields: if null bit says it's null then don't do anything, else extract
+   *   offset and size:
+   *   1) 0 <= size < row.getSizeInBytes should always be true. We can be even more precise than
+   *      this, where the upper bound of size can only be as big as the variable-length part of
+   *      the row.
+   *   2) offset should be >= the fixed-size part of the row.
+   *   3) offset + size should be within the row bounds.
+   * - For fixed-length fields that are narrower than 8 bytes (boolean/byte/short/int/float), if
+   *   null bit says it's null then don't do anything, else:
+   *     check if the unused bits in the field are all zeros. The UnsafeRowWriter's write() methods
+   *     make this guarantee.
+   * - Check the total length of the row.
+   */
+  def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = {
+    if (expectedSchema.fields.length != row.numFields) {
+      return false
+    }
+    val bitSetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields)
+    val rowSizeInBytes = row.getSizeInBytes
+    if (expectedSchema.fields.length > 0 && bitSetWidthInBytes >= rowSizeInBytes) {
+      return false
+    }
+    var varLenFieldsSizeInBytes = 0
+    expectedSchema.fields.zipWithIndex.foreach {
+      case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
+        val offsetAndSize = row.getLong(index)
+        val offset = (offsetAndSize >> 32).toInt
+        val size = offsetAndSize.toInt
+        if (size < 0 ||
+            offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) {
+          return false
+        }
+        varLenFieldsSizeInBytes += size
+      case (field, index) if UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
+        field.dataType match {
+          case BooleanType =>
+            if ((row.getLong(index) >> 1) != 0L) return false
+          case ByteType =>
+            if ((row.getLong(index) >> 8) != 0L) return false
+          case ShortType =>
+            if ((row.getLong(index) >> 16) != 0L) return false
+          case IntegerType =>
+            if ((row.getLong(index) >> 32) != 0L) return false
+          case FloatType =>
+            if ((row.getLong(index) >> 32) != 0L) return false
+          case _ =>
+        }
+      case (_, index) if row.isNullAt(index) =>
+        if (row.getLong(index) != 0L) return false
+      case _ =>
+    }
+    if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
+      return false
+    }
+    true
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f63d79a21ed..6bbeb2de7538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1237,6 +1237,16 @@ object SQLConf {
       .intConf
       .createWithDefault(10)
 
+  val STATE_STORE_FORMAT_VALIDATION_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
+      .internal()
+      .doc("When true, check if the UnsafeRow from the state store is valid or not when running " +
+        "streaming queries. This can happen if the state store format has been changed. Note " +
+        "that the feature is currently only effective in the built-in HDFS state store provider.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
     buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
       .internal()
@@ -1543,18 +1553,18 @@ object SQLConf {
 
   val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS =
     buildConf("spark.sql.streaming.checkpointFileManagerClass")
+      .internal()
       .doc("The class used to write checkpoint files atomically. This class must be a subclass " +
         "of the interface CheckpointFileManager.")
       .version("2.4.0")
-      .internal()
       .stringConf
 
   val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED =
     buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled")
+      .internal()
       .doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " +
         "to SPARK-26824.")
       .version("3.0.0")
-      .internal()
       .booleanConf
       .createWithDefault(true)
 
@@ -2746,6 +2756,8 @@ class SQLConf extends Serializable with Logging {
 
   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
+  def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
+
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
 
   def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
new file mode 100644
index 000000000000..4b6a3cfafd89
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+class UnsafeRowUtilsSuite extends SparkFunSuite {
+
+  val testKeys: Seq[String] = Seq("key1", "key2")
+  val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)")
+
+  val testOutputSchema: StructType = StructType(
+    testKeys.map(createIntegerField) ++ testValues.map(createIntegerField))
+
+  val testRow: UnsafeRow = {
+    val unsafeRowProjection = UnsafeProjection.create(testOutputSchema)
+    val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema))
+    (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
+    row
+  }
+
+  private def createIntegerField(name: String): StructField = {
+    StructField(name, IntegerType, nullable = false)
+  }
+
+  test("UnsafeRow format validation") {
+    // Passes the check
+    assert(UnsafeRowUtils.validateStructuralIntegrity(testRow, testOutputSchema))
+    // Fails when the number of fields does not match
+    assert(!UnsafeRowUtils.validateStructuralIntegrity(
+      testRow, StructType(testKeys.map(createIntegerField))))
+    // Fails for a mismatched schema
+    val invalidSchema = StructType(testKeys.map(createIntegerField) ++
+      Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
+        StructField("value2", IntegerType, false)))
+    assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 05c651f9951b..31618922e44c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -259,6 +259,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var numberOfVersionsToRetainInMemory: Int = _
+  // TODO: The validation should be moved to a higher level so that it works for all state store
+  // implementations
+  @volatile private var isValidated = false
 
   private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
   private lazy val baseDir = stateStoreId.storeCheckpointLocation()
@@ -457,6 +460,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
           // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
           // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
           valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+          if (!isValidated) {
+            StateStoreProvider.validateStateRowFormat(
+              keyRow, keySchema, valueRow, valueSchema, storeConf)
+            isValidated = true
+          }
           map.put(keyRow, valueRow)
         }
       }
@@ -551,6 +559,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
           // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
           // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
          valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+          if (!isValidated) {
+            StateStoreProvider.validateStateRowFormat(
+              keyRow, keySchema, valueRow, valueSchema, storeConf)
+            isValidated = true
+          }
           map.put(keyRow, valueRow)
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 7d80fd0c591f..092ca968f59c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -27,9 +27,10 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, SparkEnv, SparkException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
@@ -143,6 +144,16 @@ case class StateStoreCustomSumMetric(name: String, desc: String) extends StateSt
 case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
 case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric
 
+/**
+ * An exception thrown when an invalid UnsafeRow is detected in the state store.
+ */
+class InvalidUnsafeRowException
+  extends RuntimeException("The streaming query failed because the state format is invalid. " +
+    "Possible causes include: 1. The checkpoint was written by an incompatible (older) Spark " +
+    "version; 2. The checkpoint files are broken; 3. The query was changed between restarts. " +
+    "For the first case, you can try to restart the application without the checkpoint, or use " +
+    "the legacy Spark version to process the streaming state.", null)
+
 /**
  * Trait representing a provider that provide [[StateStore]] instances representing
  * versions of state data.
@@ -230,6 +241,26 @@ object StateStoreProvider {
     provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
     provider
   }
+
+  /**
+   * Use the expected schema to check whether the UnsafeRow is valid.
+   */
+  def validateStateRowFormat(
+      keyRow: UnsafeRow,
+      keySchema: StructType,
+      valueRow: UnsafeRow,
+      valueSchema: StructType,
+      conf: StateStoreConf): Unit = {
+    if (conf.formatValidationEnabled) {
+      if (!UnsafeRowUtils.validateStructuralIntegrity(keyRow, keySchema)) {
+        throw new InvalidUnsafeRowException
+      }
+      if (conf.formatValidationCheckValue &&
+          !UnsafeRowUtils.validateStructuralIntegrity(valueRow, valueSchema)) {
+        throw new InvalidUnsafeRowException
+      }
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index d145082a39b5..84d0b76ac915 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-class StateStoreConf(@transient private val sqlConf: SQLConf)
+class StateStoreConf(
+    @transient private val sqlConf: SQLConf,
+    extraOptions: Map[String, String] = Map.empty)
   extends Serializable {
 
   def this() = this(new SQLConf)
@@ -43,14 +45,25 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
    */
   val providerClass: String = sqlConf.stateStoreProviderClass
 
+  /** Whether to validate the underlying state store format or not. */
+  val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled
+
+  /** Whether to validate the value row format when format validation is enabled. */
+  val formatValidationCheckValue: Boolean =
+    extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"
+
   /**
    * Additional configurations related to state store. This will capture all configs in
-   * SQLConf that start with `spark.sql.streaming.stateStore.` */
+   * SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific
+   * operator.
+   */
   val confs: Map[String, String] =
-    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
+    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) ++ extraOptions
 }
 
 object StateStoreConf {
+  val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue"
+
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 4a69a48fed75..0eb3dce1bbd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -45,10 +45,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     valueSchema: StructType,
     indexOrdinal: Option[Int],
     sessionState: SessionState,
-    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
+    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
+    extraOptions: Map[String, String] = Map.empty)
   extends RDD[U](dataRDD) {
 
-  private val storeConf = new StateStoreConf(sessionState.conf)
+  private val storeConf = new StateStoreConf(sessionState.conf, extraOptions)
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   private val hadoopConfBroadcast = dataRDD.context.broadcast(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b6021438e902..c7a332b6d778 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -55,7 +55,8 @@ package object state {
       valueSchema: StructType,
       indexOrdinal: Option[Int],
       sessionState: SessionState,
-      storeCoordinator: Option[StateStoreCoordinatorRef])(
+      storeCoordinator: Option[StateStoreCoordinatorRef],
+      extraOptions: Map[String, String] = Map.empty)(
       storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
     val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
@@ -78,7 +79,8 @@ package object state {
         valueSchema,
         indexOrdinal,
         sessionState,
-        storeCoordinator)
+        storeCoordinator,
+        extraOptions)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 74daaf80b10e..a9c01e69b9b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
@@ -460,7 +461,10 @@ case class StreamingDeduplicateExec(
       child.output.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+      Some(sqlContext.streams.stateStoreCoordinator),
+      // We won't check the value row in the state store, since the value
+      // StreamingDeduplicateExec.EMPTY_ROW is unrelated to the output schema.
+      Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")
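To make the arithmetic behind validateStructuralIntegrity concrete, here is a worked example for the four-int-field schema used in UnsafeRowUtilsSuite above. This is an illustrative sketch, not part of the patch; it only uses classes that already appear in the diff.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// One 64-bit word covers the null bits of up to 64 fields, then every field
// occupies an 8-byte fixed-width slot.
val numFields = 4
val bitSetWidth = UnsafeRow.calculateBitSetWidthInBytes(numFields) // 8 bytes
val fixedPart = bitSetWidth + 8 * numFields                        // 40 bytes
// With no variable-length fields, getSizeInBytes of the projected test row is exactly 40,
// so the final bound check (fixedPart + varLenFieldsSizeInBytes <= getSizeInBytes) holds
// with equality, and every int slot must keep its upper 32 bits zeroed.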
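The helper can also be driven directly. A minimal sketch that mirrors the new test suite, assuming a made-up two-field schema (the field names "a" and "b" are not from the patch):

import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical schema, used only for this sketch.
val schema = StructType(Seq(
  StructField("a", IntegerType, nullable = false),
  StructField("b", IntegerType, nullable = false)))

// Build an UnsafeRow the same way UnsafeRowUtilsSuite does.
val projection = UnsafeProjection.create(schema)
val row = projection(new SpecificInternalRow(schema))
row.setInt(0, 1)
row.setInt(1, 2)

// Passes against the schema the row was written with ...
assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
// ... and fails when the expected schema has a different number of fields.
assert(!UnsafeRowUtils.validateStructuralIntegrity(row, StructType(schema.fields.take(1))))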
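Finally, a hedged sketch of the two knobs the patch introduces: the session-wide SQL conf (marked internal(), so not a documented user-facing option) and the per-operator extraOptions key used by StreamingDeduplicateExec. The SparkSession setup here is assumed, not taken from the patch, and StateStoreConf is internal API.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.StateStoreConf

val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

// Session-wide switch for state store format validation (on by default).
spark.conf.set("spark.sql.streaming.stateStore.formatValidation.enabled", "false")

// Per-operator opt-out of value-row checking only, passed as extraOptions to
// mapPartitionsWithStateStore, as StreamingDeduplicateExec does above.
val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")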