[SPARK-31894][SS] Introduce UnsafeRow format validation for streaming state store #28707
Changes from all commits: 1119756, 2153abf, 179208a, 4c919ca, b83f0c3, 0313016, fc5ad19, 12eb2a2, 01007fb, fd74ff9, 557eb30
UnsafeRowUtils.scala (new file):

```scala
@@ -0,0 +1,86 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

object UnsafeRowUtils {

  /**
   * Use the following rules to check the integrity of the UnsafeRow:
   * - schema.fields.length == row.numFields should always be true
   * - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be
   *   true if the expectedSchema contains at least one field.
   * - For variable-length fields: if the null bit says it's null then don't do anything, else
   *   extract offset and size:
   *   1) 0 <= size < row.getSizeInBytes should always be true. We can be even more precise than
   *      this, where the upper bound of size can only be as big as the variable length part of
   *      the row.
   *   2) offset should be >= fixed sized part of the row.
   *   3) offset + size should be within the row bounds.
   * - For fixed-length fields that are narrower than 8 bytes (boolean/byte/short/int/float), if
   *   the null bit says it's null then don't do anything, else:
   *   check if the unused bits in the field are all zeros. The UnsafeRowWriter's write() methods
   *   make this guarantee.
   * - Check the total length of the row.
   */
  def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = {
    if (expectedSchema.fields.length != row.numFields) {
      return false
    }
    val bitSetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields)
    val rowSizeInBytes = row.getSizeInBytes
    if (expectedSchema.fields.length > 0 && bitSetWidthInBytes >= rowSizeInBytes) {
      return false
    }
    var varLenFieldsSizeInBytes = 0
    expectedSchema.fields.zipWithIndex.foreach {
      case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
        val offsetAndSize = row.getLong(index)
        val offset = (offsetAndSize >> 32).toInt
        val size = offsetAndSize.toInt
        if (size < 0 ||
            offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) {
          return false
        }
        varLenFieldsSizeInBytes += size
      case (field, index) if UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
        field.dataType match {
          case BooleanType =>
            if ((row.getLong(index) >> 1) != 0L) return false
          case ByteType =>
            if ((row.getLong(index) >> 8) != 0L) return false
          case ShortType =>
            if ((row.getLong(index) >> 16) != 0L) return false
          case IntegerType =>
            if ((row.getLong(index) >> 32) != 0L) return false
          case FloatType =>
            if ((row.getLong(index) >> 32) != 0L) return false
          case _ =>
        }
      case (_, index) if row.isNullAt(index) =>
        if (row.getLong(index) != 0L) return false
      case _ =>
    }
    if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
      return false
    }
    true
  }
}
```
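To make the rules above concrete, here is a minimal, hypothetical usage sketch of `validateStructuralIntegrity`. It mirrors the pattern used by the test suite further down; the schema and field names are invented for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object ValidateDemo extends App {
  // Hypothetical two-field schema, for illustration only.
  val schema = StructType(Seq(
    StructField("a", IntegerType, nullable = false),
    StructField("b", IntegerType, nullable = false)))

  // Build a well-formed UnsafeRow through UnsafeProjection, as the test suite does.
  val row = UnsafeProjection.create(schema)(new SpecificInternalRow(schema))
  row.setInt(0, 1)
  row.setInt(1, 2)

  // A matching schema passes the check ...
  assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))

  // ... while a schema with a different field count fails the first rule.
  val narrower = StructType(Seq(StructField("a", IntegerType, nullable = false)))
  assert(!UnsafeRowUtils.validateStructuralIntegrity(row, narrower))
}
```

Building the row through `UnsafeProjection` guarantees a well-formed layout, which is why malformed rows can only come from outside sources such as old or corrupted checkpoint files.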
SQLConf.scala:

```scala
@@ -1237,6 +1237,16 @@ object SQLConf {
      .intConf
      .createWithDefault(10)

  val STATE_STORE_FORMAT_VALIDATION_ENABLED =
    buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
      .internal()
      .doc("When true, check if the UnsafeRow from the state store is valid or not when " +
        "running streaming queries. An invalid row can appear if the state store format has " +
        "been changed. Note, the feature is only effective in the built-in HDFS state store " +
        "provider now.")
      .version("3.1.0")
      .booleanConf
      .createWithDefault(true)

  val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
    buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
      .internal()
```

> **Member** (on the `.doc(...)` string): Change …
>
> **Author:** Sure, will submit a follow-up PR today.

```scala
@@ -1543,18 +1553,18 @@ object SQLConf {
  val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS =
    buildConf("spark.sql.streaming.checkpointFileManagerClass")
      .internal()
      .doc("The class used to write checkpoint files atomically. This class must be a subclass " +
        "of the interface CheckpointFileManager.")
      .version("2.4.0")
      .stringConf

  val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED =
    buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled")
      .internal()
      .doc("Whether to detect that a streaming query may pick up an incorrect checkpoint path " +
        "due to SPARK-26824.")
      .version("3.0.0")
      .booleanConf
      .createWithDefault(true)
```

```scala
@@ -2746,6 +2756,8 @@ class SQLConf extends Serializable with Logging {
  def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

  def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)

  def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)

  def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
```
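The new flag is internal and defaults to true, so it normally needs no action. As a hedged sketch of how one could turn it off for a session, for example in `spark-shell` (the session setup here is illustrative, not taken from this PR):

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a local session purely for demonstration.
val spark = SparkSession.builder().master("local[2]").appName("conf-demo").getOrCreate()

// The conf is internal, but it can be set and read like any other runtime conf.
spark.conf.set("spark.sql.streaming.stateStore.formatValidation.enabled", "false")
assert(spark.conf.get("spark.sql.streaming.stateStore.formatValidation.enabled") == "false")
```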
UnsafeRowUtilsSuite.scala (new file):

```scala
@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

  val testKeys: Seq[String] = Seq("key1", "key2")
  val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)")

  val testOutputSchema: StructType = StructType(
    testKeys.map(createIntegerField) ++ testValues.map(createIntegerField))

  val testRow: UnsafeRow = {
    val unsafeRowProjection = UnsafeProjection.create(testOutputSchema)
    val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema))
    (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
    row
  }

  private def createIntegerField(name: String): StructField = {
    StructField(name, IntegerType, nullable = false)
  }

  test("UnsafeRow format validation") {
    // Passes the check
    assert(UnsafeRowUtils.validateStructuralIntegrity(testRow, testOutputSchema))
    // Fails when the number of fields does not match
    assert(!UnsafeRowUtils.validateStructuralIntegrity(
      testRow, StructType(testKeys.map(createIntegerField))))
    // Fails for an invalid schema
    val invalidSchema = StructType(testKeys.map(createIntegerField) ++
      Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
        StructField("value2", IntegerType, false)))
    assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema))
  }
}
```
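As a small worked illustration of the fixed-width rule that the validator and this suite rely on: an `IntegerType` field occupies only the low 4 bytes of its 8-byte word, and `UnsafeRowWriter` zeroes the word before writing, so any set bit in the upper half signals corruption. The values below are invented for illustration:

```scala
object FixedWidthRuleDemo extends App {
  // A valid IntegerType word keeps its high 32 bits at zero.
  val validWord: Long = 42L                 // 0x000000000000002A
  // A corrupt word has stray bits in the unused upper half.
  val corruptWord: Long = (1L << 40) | 42L

  assert((validWord >> 32) == 0L)   // passes the validator's IntegerType rule
  assert((corruptWord >> 32) != 0L) // would make validateStructuralIntegrity return false
}
```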
HDFSBackedStateStoreProvider.scala:

```scala
@@ -259,6 +259,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
  @volatile private var storeConf: StateStoreConf = _
  @volatile private var hadoopConf: Configuration = _
  @volatile private var numberOfVersionsToRetainInMemory: Int = _
  // TODO: The validation should be moved to a higher level so that it works for all state store
  // implementations
  @volatile private var isValidated = false

  private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
  private lazy val baseDir = stateStoreId.storeCheckpointLocation()
```

> **Contributor:** Can we add a TODO that this validation should be moved to a higher level so that it works for all state store implementations?
>
> **Author:** Thanks, added the TODO in fd74ff9.

Because `isValidated` is a per-provider-instance flag, the format check below runs only on the first key/value pair read while loading state, not on every row:

```scala
@@ -457,6 +460,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
        // Prior to Spark 2.3, `RowBasedKeyValueBatch` mistakenly appended 4 bytes to the
        // value row, which got persisted into the checkpoint data
        valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
        if (!isValidated) {
          StateStoreProvider.validateStateRowFormat(
            keyRow, keySchema, valueRow, valueSchema, storeConf)
          isValidated = true
        }
        map.put(keyRow, valueRow)
      }
    }

@@ -551,6 +559,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
        // Prior to Spark 2.3, `RowBasedKeyValueBatch` mistakenly appended 4 bytes to the
        // value row, which got persisted into the checkpoint data
        valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
        if (!isValidated) {
          StateStoreProvider.validateStateRowFormat(
            keyRow, keySchema, valueRow, valueSchema, storeConf)
          isValidated = true
        }
        map.put(keyRow, valueRow)
      }
    }
```
StateStore.scala:

```scala
@@ -27,9 +27,10 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkEnv, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}

@@ -143,6 +144,16 @@ case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric

/**
 * An exception thrown when an invalid UnsafeRow is detected in the state store.
 */
class InvalidUnsafeRowException
  extends RuntimeException("The streaming query failed because state format validation found " +
    "an invalid row. Possible causes: 1. An old Spark version wrote the checkpoint and it is " +
    "incompatible with the current one; 2. The checkpoint files are broken; 3. The query " +
    "changed between restarts. For the first case, you can try to restart the application " +
    "without the checkpoint, or use the legacy Spark version to process the streaming state.",
    null)

/**
 * Trait representing a provider that provides [[StateStore]] instances representing
 * versions of state data.
```

> **Author** (replying to a reviewer comment on the error message): The resolution given is for the first case; the other listed causes should be considered user problems.

```scala
@@ -230,6 +241,26 @@ object StateStoreProvider {
    provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
    provider
  }

  /**
   * Use the expected schema to check whether the UnsafeRow is valid.
   */
  def validateStateRowFormat(
      keyRow: UnsafeRow,
      keySchema: StructType,
      valueRow: UnsafeRow,
      valueSchema: StructType,
      conf: StateStoreConf): Unit = {
    if (conf.formatValidationEnabled) {
      if (!UnsafeRowUtils.validateStructuralIntegrity(keyRow, keySchema)) {
        throw new InvalidUnsafeRowException
      }
      if (conf.formatValidationCheckValue &&
          !UnsafeRowUtils.validateStructuralIntegrity(valueRow, valueSchema)) {
        throw new InvalidUnsafeRowException
      }
    }
  }
}
```
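The error message points to restarting without the old checkpoint as the remedy for the first cause. Here is a hypothetical sketch of that recovery path, with an invented rate-source query and checkpoint path (none of this is taken from the PR itself):

```scala
import org.apache.spark.sql.SparkSession

object FreshCheckpointRestart extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("fresh-checkpoint-demo")
    .getOrCreate()
  import spark.implicits._

  // Illustrative stateful query: a running count over a synthetic rate source.
  val counts = spark.readStream.format("rate").load()
    .groupBy(($"value" % 10).as("bucket"))
    .count()

  // Pointing at a NEW checkpoint location rebuilds the state in the current
  // format instead of reading the old, incompatible state files.
  val query = counts.writeStream
    .outputMode("complete")
    .format("console")
    .option("checkpointLocation", "/tmp/state-demo-checkpoint-v2") // fresh path
    .start()

  query.awaitTermination()
}
```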