Address comments: 1. enhance the validation of UnsafeRow; 2. check the format for all state stores
xuanyuanking committed Jun 17, 2020
commit b83f0c309f3c2dd02dcfc65f39d2ae23864b69c6
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

object UnsafeRowUtils {

  /**
   * Use the following rules to check the integrity of the UnsafeRow:
   * - schema.fields.length == row.numFields should always be true
   * - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be
   *   true if the expectedSchema contains at least one field.
   * - For variable-length fields: if the null bit says it's null, then don't do anything, else
   *   extract the offset and size:
   *   1) 0 <= size < row.getSizeInBytes should always be true. We could be even more precise
   *      than this: the upper bound of size can only be as big as the variable-length part of
   *      the row.
   *   2) offset should be >= the size of the fixed-length part of the row.
   *   3) offset + size should be within the row bounds.
   * - For fixed-length fields that are narrower than 8 bytes (boolean/byte/short/int/float), if
   *   the null bit says it's null, then don't do anything, else check that the unused bits in
   *   the field are all zeros. The UnsafeRowWriter's write() methods make this guarantee.
   * - Check the total length of the row.
   */
  def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = {
    if (expectedSchema.fields.length != row.numFields) {
      return false
    }
    val bitSetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields)
    val rowSizeInBytes = row.getSizeInBytes
    if (expectedSchema.fields.length > 0 && bitSetWidthInBytes >= rowSizeInBytes) {
      return false
    }
    var varLenFieldsSizeInBytes = 0
    expectedSchema.fields.zipWithIndex.foreach {
      case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
        // For a variable-length field, the fixed-length slot stores the offset in the upper
        // 32 bits and the size in the lower 32 bits.
        val offsetAndSize = row.getLong(index)
        val offset = (offsetAndSize >> 32).toInt
        val size = offsetAndSize.toInt
        if (size < 0 ||
            offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) {
          return false
        }
        varLenFieldsSizeInBytes += size
      case (field, index) if UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
        field.dataType match {
          case BooleanType =>
            if ((row.getLong(index) >> 1) != 0L) return false
          case ByteType =>
            if ((row.getLong(index) >> 8) != 0L) return false
          case ShortType =>
            if ((row.getLong(index) >> 16) != 0L) return false
          case IntegerType =>
            if ((row.getLong(index) >> 32) != 0L) return false
          case FloatType =>
            if ((row.getLong(index) >> 32) != 0L) return false
          case _ =>
        }
      case _ =>
        // Null fields need no further checks.
    }
    if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
      return false
    }
    true
  }
}
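
For reference, a minimal sketch of how the new helper behaves; the schema and values below are illustrative only, not part of this patch:

// Sketch (not part of this patch): validating a hand-built UnsafeRow.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("count", LongType, nullable = false)))
val row = UnsafeProjection.create(schema).apply(InternalRow(1, 2L))

// The row matches the schema it was written with.
assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
// A schema with a different number of fields is rejected.
assert(!UnsafeRowUtils.validateStructuralIntegrity(
  row, StructType(Seq(StructField("id", IntegerType, nullable = false)))))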
@@ -1558,12 +1558,11 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED =
buildConf("spark.sql.streaming.aggregationStateFormatCheck.enabled")
val STREAMING_STATE_FORMAT_CHECK_ENABLED =
Contributor: This is misleading - we're only detecting the case from streaming aggregation.

Contributor: BTW, should we have a configuration for this, given that it only does an essential check that all rows must pass anyway?

Member Author: Thanks, renamed it in ee048bc. Since it's an extra check and still has some overhead, I keep the feature flag for safety.

buildConf("spark.sql.streaming.stateFormatCheck.enabled")
.internal()
Contributor: we usually put internal() right after buildConf(...)

Member Author: Thanks, done in 10a7980.

.doc("When true, check if the UnsafeRow from the state store is valid or not when running " +
"streaming aggregation queries. This can happen if the state store format has been " +
"changed.")
"streaming queries. This can happen if the state store format has been changed.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)
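
Since this is an internal flag, it is mainly an escape hatch; a minimal sketch of turning it off for a session (the session setup below is illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("state-format-check").getOrCreate()
// Skip the UnsafeRow format check when reading streaming state (use only for debugging).
spark.conf.set("spark.sql.streaming.stateFormatCheck.enabled", "false")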
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

  val testKeys: Seq[String] = Seq("key1", "key2")
  val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)")

  val testOutputSchema: StructType = StructType(
    testKeys.map(createIntegerField) ++ testValues.map(createIntegerField))

  val testRow: UnsafeRow = {
    val unsafeRowProjection = UnsafeProjection.create(testOutputSchema)
    val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema))
    (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
    row
  }

  private def createIntegerField(name: String): StructField = {
    StructField(name, IntegerType, nullable = false)
  }

  test("UnsafeRow format invalidation") {
    // Passes the check against the schema the row was built with
    assert(UnsafeRowUtils.validateStructuralIntegrity(testRow, testOutputSchema))
    // Fails when the number of fields does not match
    assert(!UnsafeRowUtils.validateStructuralIntegrity(
      testRow, StructType(testKeys.map(createIntegerField))))
    // Fails for a schema that expects a variable-length struct where the row stores an int
    val invalidSchema = StructType(testKeys.map(createIntegerField) ++
      Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
        StructField("value2", IntegerType, false)))
    assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema))
  }
}
@@ -259,6 +259,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _
@volatile private var numberOfVersionsToRetainInMemory: Int = _
@volatile private var isValidated = false
Contributor (@cloud-fan, Jun 17, 2020): Can we add a TODO that this validation should be moved to a higher level so that it works for all state store implementations?

Member Author: Thanks, added the TODO in fd74ff9.


private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
private lazy val baseDir = stateStoreId.storeCheckpointLocation()
@@ -457,6 +458,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
// Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
// `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
if (!isValidated) {
StateStoreProvider.validateStateRowFormat(keyRow, keySchema)
StateStoreProvider.validateStateRowFormat(valueRow, valueSchema)
isValidated = true
}
map.put(keyRow, valueRow)
}
}
@@ -551,6 +557,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
// Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
// `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
if (!isValidated) {
StateStoreProvider.validateStateRowFormat(keyRow, keySchema)
StateStoreProvider.validateStateRowFormat(valueRow, valueSchema)
isValidated = true
}
map.put(keyRow, valueRow)
}
}
@@ -27,10 +27,12 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.{SparkContext, SparkEnv, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}

@@ -143,6 +145,16 @@ case class StateStoreCustomSumMetric(name: String, desc: String) extends StateSt
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric

/**
* An exception thrown when an invalid UnsafeRow is detected in the state store.
*/
class InvalidUnsafeRowException
extends SparkException("The streaming query failed because the state format is invalid. " +
Contributor: Does it have to be SparkException?

Member Author: No, changed it to RuntimeException. Done in fd74ff9.

"The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " +
"incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " +
"among restart. For the first case, you can try to restart the application without " +
Contributor: "For the first case" - I think it's for all the cases?

Member Author: The resolution is only for the first case; the remaining cases should be considered user problems.

"checkpoint or use the legacy Spark version to process the streaming state.", null)

/**
* Trait representing a provider that provide [[StateStore]] instances representing
* versions of state data.
@@ -230,6 +242,17 @@ object StateStoreProvider {
provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider
}

/**
* Use the expected schema to check whether the UnsafeRow is valid.
*/
def validateStateRowFormat(row: UnsafeRow, schema: StructType): Unit = {
if (SQLConf.get.getConf(SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED)) {
if (!UnsafeRowUtils.validateStructuralIntegrity(row, schema)) {
throw new InvalidUnsafeRowException
}
}
}
}
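
A minimal sketch of what callers can expect from validateStateRowFormat (the row construction and schemas below are illustrative, not from this patch): with the check enabled, which is the default, a mismatched row surfaces as InvalidUnsafeRowException; with spark.sql.streaming.stateFormatCheck.enabled set to false the call is a no-op.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.streaming.state.{InvalidUnsafeRowException, StateStoreProvider}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val valueSchema = StructType(Seq(StructField("sum", LongType, nullable = false)))
val valueRow = UnsafeProjection.create(valueSchema).apply(InternalRow(42L))

// Matching schema: passes silently.
StateStoreProvider.validateStateRowFormat(valueRow, valueSchema)

// Wrong number of fields: throws while the check is enabled.
try {
  StateStoreProvider.validateStateRowFormat(valueRow, StructType(Seq(
    StructField("a", LongType, nullable = false), StructField("b", LongType, nullable = false))))
} catch {
  case _: InvalidUnsafeRowException => // expected
}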

/**
@@ -17,11 +17,9 @@

package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

/**
@@ -61,9 +59,6 @@ sealed trait StreamingAggregationStateManager extends Serializable {

/** Return an iterator containing all the values in target state store. */
def values(store: StateStore): Iterator[UnsafeRow]

/** Check the UnsafeRow format with the expected schema */
def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit
}

object StreamingAggregationStateManager extends Logging {
@@ -82,26 +77,13 @@ object StreamingAggregationStateManager extends Logging {
}
}

/**
* An exception thrown when an invalid UnsafeRow is detected.
*/
class InvalidUnsafeRowException
extends SparkException("The streaming aggregation query failed by state format invalidation. " +
"The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " +
"incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " +
"among restart. For the first case, you can try to restart the application without " +
"checkpoint or use the legacy Spark version to process the streaming state.", null)

abstract class StreamingAggregationStateManagerBaseImpl(
protected val keyExpressions: Seq[Attribute],
protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

@transient protected lazy val keyProjector =
GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

// Consider about the cost, only check the UnsafeRow format for the first row
private var checkFormat = true

override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row)

override def commit(store: StateStore): Long = store.commit()
@@ -112,28 +94,6 @@ abstract class StreamingAggregationStateManagerBaseImpl(
// discard and don't convert values to avoid computation
store.getRange(None, None).map(_.key)
}

override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
if (checkFormat && SQLConf.get.getConf(
SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED) && row != null) {
if (schema.fields.length != row.numFields) {
throw new InvalidUnsafeRowException
}
schema.fields.zipWithIndex
.filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach {
case (_, index) =>
val offsetAndSize = row.getLong(index)
val offset = (offsetAndSize >> 32).toInt
val size = offsetAndSize.toInt
if (size < 0 ||
offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields ||
offset + size > row.getSizeInBytes) {
throw new InvalidUnsafeRowException
}
}
checkFormat = false
}
}
}

/**
Expand All @@ -154,9 +114,7 @@ class StreamingAggregationStateManagerImplV1(
override def getStateValueSchema: StructType = inputRowAttributes.toStructType

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
val res = store.get(key)
unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
res
store.get(key)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
@@ -215,9 +173,7 @@
return savedState
}

val res = restoreOriginalRow(key, savedState)
unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
res
restoreOriginalRow(key, savedState)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
@@ -123,24 +123,4 @@ class StreamingAggregationStateManagerSuite extends StreamTest {
// state manager should return row which is same as input row regardless of format version
assert(inputRow === stateManager.get(memoryStateStore, keyRow))
}

test("UnsafeRow format invalidation") {
// Pass the checking
val stateManager0 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
testOutputAttributes, 2)
stateManager0.unsafeRowFormatValidation(testRow, testOutputSchema)
// Fail for fields number not match
val stateManager1 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
testOutputAttributes, 2)
assertThrows[InvalidUnsafeRowException](stateManager1.unsafeRowFormatValidation(
testRow, StructType(testKeys.map(createIntegerField))))
// Fail for invalid schema
val stateManager2 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
testOutputAttributes, 2)
val invalidSchema = StructType(testKeys.map(createIntegerField) ++
Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
StructField("value2", IntegerType, false)))
assertThrows[InvalidUnsafeRowException](stateManager2.unsafeRowFormatValidation(
testRow, invalidSchema))
}
}