From 1119756323c7349ece51ed16c006748a20670a75 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 3 Jun 2020 02:28:33 +0800 Subject: [PATCH 01/11] initial commit --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../StreamingAggregationStateManager.scala | 46 ++++++++++++++++++- ...treamingAggregationStateManagerSuite.scala | 22 ++++++++- 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7f63d79a21ed..c278a9c3d9fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1558,6 +1558,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STREAMING_STATE_FORMAT_CHECK_ENABLED = + buildConf("spark.sql.streaming.stateFormatCheck.enabled") + .doc("Whether to detect a streaming query may try to use an invalid UnsafeRow in the " + + "state store.") + .version("3.1.0") + .internal() + .booleanConf + .createWithDefault(true) + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index 9bfb9561b42a..f963b5c95f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.streaming.state +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** @@ -59,6 +61,9 @@ sealed trait StreamingAggregationStateManager extends Serializable { /** Return an iterator containing all the values in target state store. */ def values(store: StateStore): Iterator[UnsafeRow] + + /** Check the UnsafeRow format with the expected schema */ + def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit } object StreamingAggregationStateManager extends Logging { @@ -77,6 +82,14 @@ object StreamingAggregationStateManager extends Logging { } } +/** + * An exception thrown when an invalid UnsafeRow is detected. + */ +class InvalidUnsafeRowException + extends SparkException("The UnsafeRow format is invalid. This may happen when using the old " + + "version or broken checkpoint file. 
To resolve this problem, you can try to restart the " + + "application or use the legacy way to process streaming state.", null) + abstract class StreamingAggregationStateManagerBaseImpl( protected val keyExpressions: Seq[Attribute], protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { @@ -84,6 +97,9 @@ abstract class StreamingAggregationStateManagerBaseImpl( @transient protected lazy val keyProjector = GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + // Consider about the cost, only check the UnsafeRow format for the first row + private var checkFormat = true + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) override def commit(store: StateStore): Long = store.commit() @@ -94,6 +110,28 @@ abstract class StreamingAggregationStateManagerBaseImpl( // discard and don't convert values to avoid computation store.getRange(None, None).map(_.key) } + + override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = { + if (checkFormat && SQLConf.get.getConf( + SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED) && row != null) { + if (schema.fields.length != row.numFields) { + throw new InvalidUnsafeRowException + } + schema.fields.zipWithIndex + .filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach { + case (_, index) => + val offsetAndSize = row.getLong(index) + val offset = (offsetAndSize >> 32).toInt + val size = offsetAndSize.toInt + if (size < 0 || + offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields || + offset + size > row.getSizeInBytes) { + throw new InvalidUnsafeRowException + } + } + checkFormat = false + } + } } /** @@ -114,7 +152,9 @@ class StreamingAggregationStateManagerImplV1( override def getStateValueSchema: StructType = inputRowAttributes.toStructType override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - store.get(key) + val res = store.get(key) + unsafeRowFormatValidation(res, inputRowAttributes.toStructType) + res } override def put(store: StateStore, row: UnsafeRow): Unit = { @@ -173,7 +213,9 @@ class StreamingAggregationStateManagerImplV2( return savedState } - restoreOriginalRow(key, savedState) + val res = restoreOriginalRow(key, savedState) + unsafeRowFormatValidation(res, inputRowAttributes.toStructType) + res } override def put(store: StateStore, row: UnsafeRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index daacdfd58c7b..2881e2e6f6c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class StreamingAggregationStateManagerSuite extends StreamTest { // ============================ fields and method for test data ============================ @@ 
-123,4 +123,24 @@ class StreamingAggregationStateManagerSuite extends StreamTest { // state manager should return row which is same as input row regardless of format version assert(inputRow === stateManager.get(memoryStateStore, keyRow)) } + + test("UnsafeRow format invalidation") { + // Pass the checking + val stateManager0 = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + stateManager0.unsafeRowFormatValidation(testRow, testOutputSchema) + // Fail for fields number not match + val stateManager1 = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + assertThrows[InvalidUnsafeRowException](stateManager1.unsafeRowFormatValidation( + testRow, StructType(testKeys.map(createIntegerField)))) + // Fail for invalid schema + val stateManager2 = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + val invalidSchema = StructType(testKeys.map(createIntegerField) ++ + Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true), + StructField("value2", IntegerType, false))) + assertThrows[InvalidUnsafeRowException](stateManager2.unsafeRowFormatValidation( + testRow, invalidSchema)) + } } From 2153abfd3be86a659daa1bb9b9bc29cda4bf3665 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 3 Jun 2020 01:49:08 +0800 Subject: [PATCH 02/11] initial commit --- ...reamingAggregationCompatibilitySuite.scala | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala new file mode 100644 index 000000000000..9ecbbad25ef7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.util.Utils + +class StreamingAggregationCompatibilitySuite extends StreamTest { + import testImplicits._ + + test("common functions") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF().toDF("value") + .selectExpr( + "value", + "value % 5 AS id", + "CAST(value AS STRING) as str", + "CAST(value AS FLOAT) as f", + "CAST(value AS DOUBLE) as d", + "CAST(value AS DECIMAL) as dec", + "value % 3 AS mod", + "named_struct('key', CAST(value AS STRING), 'value', value) AS s") + .groupBy($"id") + .agg( + avg($"value").as("avg_v"), + avg($"f").as("avg_f"), + avg($"d").as("avg_d"), + avg($"dec").as("avg_dec"), + count($"value").as("cnt"), + first($"value").as("first_v"), + first($"s").as("first_s"), + last($"value").as("last_v"), + last($"s").as("last_s"), + min(struct("value", "str")).as("min_struct"), + max($"value").as("max_v"), + sum($"value").as("sum_v"), + sum($"f").as("sum_f"), + sum($"d").as("sum_d"), + sum($"dec").as("sum_dec"), + collect_list($"value").as("col_list"), + collect_set($"mod").as("col_set")) + .select("id", "avg_v", "avg_f", "avg_d", "avg_dec", "cnt", "first_v", "first_s.value", + "last_v", "last_s.value", "min_struct.value", "max_v", "sum_v", "sum_f", "sum_d", + "sum_dec", "col_list", "col_set") + + val resourceUri = this.getClass.getResource("/structured-streaming/" + + "checkpoint-version-2.4.5-for-compatibility-test-common-functions").toURI + val checkpointDir = Utils.createTempDir().getCanonicalFile + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(0 to 9: _*) + + testStream(aggregated, Complete)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), +// AddData(inputData, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9), +// CheckAnswer( +// Row(0, 2.5, 2.5F, 2.5, 2.5000, 2, 0, 0, 5, 5, 0, 5, 5, 5.0, 5.0, 5, Seq(0, 5), +// Seq(0, 2)), +// Row(1, 3.5, 3.5F, 3.5, 3.5000, 2, 1, 1, 6, 6, 1, 6, 7, 7.0, 7.0, 7, Seq(1, 6), +// Seq(0, 1)), +// Row(2, 4.5, 4.5F, 4.5, 4.5000, 2, 2, 2, 7, 7, 2, 7, 9, 9.0, 9.0, 9, Seq(2, 7), +// Seq(1, 2)), +// Row(3, 5.5, 5.5F, 5.5, 5.5000, 2, 3, 3, 8, 8, 3, 8, 11, 11.0, 11.0, 11, Seq(3, 8), +// Seq(0, 2)), +// Row(4, 6.5, 6.5F, 6.5, 6.5000, 2, 4, 4, 9, 9, 4, 9, 13, 13.0, 13.0, 13, Seq(4, 9), +// Seq(0, 1))), + AddData(inputData, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), + CheckAnswer( + Row(0, 7.5, 7.5, 7.5, 7.5000, 4, 0, 0, 15, 15, 0, 15, 30, 30.0, 30.0, 30, + Seq(0, 5, 10, 15), Seq(0, 1, 2)), + Row(1, 8.5, 8.5, 8.5, 8.5000, 4, 1, 1, 16, 16, 1, 16, 34, 34.0, 34.0, 34, + Seq(1, 6, 11, 16), Seq(0, 1, 2)), + Row(2, 9.5, 9.5, 9.5, 9.5000, 4, 2, 2, 17, 17, 2, 17, 38, 38.0, 38.0, 38, + Seq(2, 7, 12, 17), Seq(0, 1, 2)), + Row(3, 10.5, 10.5, 10.5, 10.5000, 4, 3, 3, 18, 18, 3, 18, 42, 42.0, 42.0, 42, + Seq(3, 8, 13, 18), Seq(0, 1, 2)), + Row(4, 11.5, 11.5, 11.5, 11.5000, 4, 4, 4, 19, 19, 4, 19, 46, 46.0, 46.0, 46, + Seq(4, 9, 14, 19), Seq(0, 1, 2))) + ) + } +} From 179208ad7f5f3c85144f3550dfecfeb9208a4b13 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 3 Jun 2020 16:59:21 +0800 Subject: [PATCH 03/11] address comments --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 8 ++++---- .../state/StreamingAggregationStateManager.scala | 10 ++++++---- 2 
files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c278a9c3d9fd..092bbdd12aa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1558,10 +1558,10 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STREAMING_STATE_FORMAT_CHECK_ENABLED = - buildConf("spark.sql.streaming.stateFormatCheck.enabled") - .doc("Whether to detect a streaming query may try to use an invalid UnsafeRow in the " + - "state store.") + val STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED = + buildConf("spark.sql.streaming.aggregationStateFormatCheck.enabled") + .doc("Whether to detect a streaming aggregation query may try to use an invalid UnsafeRow " + + "in the state store.") .version("3.1.0") .internal() .booleanConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index f963b5c95f5e..124ccd623ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -86,9 +86,11 @@ object StreamingAggregationStateManager extends Logging { * An exception thrown when an invalid UnsafeRow is detected. */ class InvalidUnsafeRowException - extends SparkException("The UnsafeRow format is invalid. This may happen when using the old " + - "version or broken checkpoint file. To resolve this problem, you can try to restart the " + - "application or use the legacy way to process streaming state.", null) + extends SparkException("The streaming aggregation query failed by state format invalidation. " + + "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " + + "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " + + "among restart. 
For the first case, you can try to restart the application without " + + "checkpoint or use the legacy Spark version to process the streaming state.", null) abstract class StreamingAggregationStateManagerBaseImpl( protected val keyExpressions: Seq[Attribute], @@ -113,7 +115,7 @@ abstract class StreamingAggregationStateManagerBaseImpl( override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = { if (checkFormat && SQLConf.get.getConf( - SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED) && row != null) { + SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED) && row != null) { if (schema.fields.length != row.numFields) { throw new InvalidUnsafeRowException } From 4c919caa1507b69e4e5ab3955da97fea10794090 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 3 Jun 2020 17:20:27 +0800 Subject: [PATCH 04/11] address comments --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 092bbdd12aa3..addcb9d0eec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1543,27 +1543,28 @@ object SQLConf { val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = buildConf("spark.sql.streaming.checkpointFileManagerClass") + .internal() .doc("The class used to write checkpoint files atomically. This class must be a subclass " + "of the interface CheckpointFileManager.") .version("2.4.0") - .internal() .stringConf val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED = buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled") + .internal() .doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " + "to SPARK-26824.") .version("3.0.0") - .internal() .booleanConf .createWithDefault(true) val STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED = buildConf("spark.sql.streaming.aggregationStateFormatCheck.enabled") - .doc("Whether to detect a streaming aggregation query may try to use an invalid UnsafeRow " + - "in the state store.") - .version("3.1.0") .internal() + .doc("When true, check if the UnsafeRow from the state store is valid or not when running " + + "streaming aggregation queries. 
This can happen if the state store format has been " + + "changed.") + .version("3.1.0") .booleanConf .createWithDefault(true) From b83f0c309f3c2dd02dcfc65f39d2ae23864b69c6 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 5 Jun 2020 22:42:30 +0800 Subject: [PATCH 05/11] Address comments:1.enhance the validation of unsafe row; 2.check the format for all state store --- .../sql/catalyst/util/UnsafeRowUtils.scala | 83 ++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 7 +- .../catalyst/util/UnsafeRowUtilsSuite.scala | 55 +++++++++ .../state/HDFSBackedStateStoreProvider.scala | 11 ++ .../streaming/state/StateStore.scala | 25 ++++- .../StreamingAggregationStateManager.scala | 48 +------- ...treamingAggregationStateManagerSuite.scala | 20 ---- ...reamingAggregationCompatibilitySuite.scala | 105 ------------------ 8 files changed, 178 insertions(+), 176 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala new file mode 100644 index 000000000000..0a5e2e9a00fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ + +object UnsafeRowUtils { + + /** + * Use the following rules to check the integrity of the UnsafeRow: + * - schema.fields.length == row.numFields should always be true + * - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be + * true if the expectedSchema contains at least one field. + * - For variable-length fields: if null bit says it's null then don't do anything, else extract + * offset and size: + * 1) 0 <= size < row.getSizeInBytes should always be true. We can be even more precise than + * this, where the upper bound of size can only be as big as the variable length part of + * the row. + * 2) offset should be >= fixed sized part of the row. + * 3) offset + size should be within the row bounds. + * - For fixed-length fields that are narrower than 8 bytes (boolean/byte/short/int/float), if + * null bit says it's null then don't do anything, else: + * check if the unused bits in the field are all zeros. 
The UnsafeRowWriter's write() methods + * make this guarantee. + * - Check the total length of the row. + */ + def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = { + if (expectedSchema.fields.length != row.numFields) { + return false + } + val bitSetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + val rowSizeInBytes = row.getSizeInBytes + if (expectedSchema.fields.length > 0 && bitSetWidthInBytes >= rowSizeInBytes) { + return false + } + var varLenFieldsSizeInBytes = 0 + expectedSchema.fields.zipWithIndex.foreach { + case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) => + val offsetAndSize = row.getLong(index) + val offset = (offsetAndSize >> 32).toInt + val size = offsetAndSize.toInt + if (size < 0 || + offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) { + return false + } + varLenFieldsSizeInBytes += size + case (field, index) if UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) => + field.dataType match { + case BooleanType => + if ((row.getLong(index) >> 1) != 0L) return false + case ByteType => + if ((row.getLong(index) >> 8) != 0L) return false + case ShortType => + if ((row.getLong(index) >> 16) != 0L) return false + case IntegerType => + if ((row.getLong(index) >> 32) != 0L) return false + case FloatType => + if ((row.getLong(index) >> 32) != 0L) return false + case _ => + } + } + if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) { + return false + } + true + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index addcb9d0eec8..e92b913e67c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1558,12 +1558,11 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED = - buildConf("spark.sql.streaming.aggregationStateFormatCheck.enabled") + val STREAMING_STATE_FORMAT_CHECK_ENABLED = + buildConf("spark.sql.streaming.stateFormatCheck.enabled") .internal() .doc("When true, check if the UnsafeRow from the state store is valid or not when running " + - "streaming aggregation queries. This can happen if the state store format has been " + - "changed.") + "streaming queries. This can happen if the state store format has been changed.") .version("3.1.0") .booleanConf .createWithDefault(true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala new file mode 100644 index 000000000000..4b6a3cfafd89 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +class UnsafeRowUtilsSuite extends SparkFunSuite { + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + test("UnsafeRow format invalidation") { + // Pass the checking + UnsafeRowUtils.validateStructuralIntegrity(testRow, testOutputSchema) + // Fail for fields number not match + assert(!UnsafeRowUtils.validateStructuralIntegrity( + testRow, StructType(testKeys.map(createIntegerField)))) + // Fail for invalid schema + val invalidSchema = StructType(testKeys.map(createIntegerField) ++ + Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true), + StructField("value2", IntegerType, false))) + assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 05c651f9951b..8d6e4523eb37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -259,6 +259,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ @volatile private var numberOfVersionsToRetainInMemory: Int = _ + @volatile private var isValidated = false private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse) private lazy val baseDir = stateStoreId.storeCheckpointLocation() @@ -457,6 +458,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) + if (!isValidated) { + StateStoreProvider.validateStateRowFormat(keyRow, keySchema) + StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + isValidated = true + } map.put(keyRow, valueRow) } } @@ -551,6 +557,11 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) + if (!isValidated) { + StateStoreProvider.validateStateRowFormat(keyRow, keySchema) + StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + isValidated = true + } map.put(keyRow, valueRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7d80fd0c591f..39a924c7f1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -27,10 +27,12 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} @@ -143,6 +145,16 @@ case class StateStoreCustomSumMetric(name: String, desc: String) extends StateSt case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric +/** + * An exception thrown when an invalid UnsafeRow is detected in state store. + */ +class InvalidUnsafeRowException + extends SparkException("The streaming query failed by state format invalidation. " + + "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " + + "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " + + "among restart. For the first case, you can try to restart the application without " + + "checkpoint or use the legacy Spark version to process the streaming state.", null) + /** * Trait representing a provider that provide [[StateStore]] instances representing * versions of state data. @@ -230,6 +242,17 @@ object StateStoreProvider { provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider } + + /** + * Use the expected schema to check whether the UnsafeRow is valid. 
+ */ + def validateStateRowFormat(row: UnsafeRow, schema: StructType): Unit = { + if (SQLConf.get.getConf(SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED)) { + if (!UnsafeRowUtils.validateStructuralIntegrity(row, schema)) { + throw new InvalidUnsafeRowException + } + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index 124ccd623ea8..9bfb9561b42a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** @@ -61,9 +59,6 @@ sealed trait StreamingAggregationStateManager extends Serializable { /** Return an iterator containing all the values in target state store. */ def values(store: StateStore): Iterator[UnsafeRow] - - /** Check the UnsafeRow format with the expected schema */ - def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit } object StreamingAggregationStateManager extends Logging { @@ -82,16 +77,6 @@ object StreamingAggregationStateManager extends Logging { } } -/** - * An exception thrown when an invalid UnsafeRow is detected. - */ -class InvalidUnsafeRowException - extends SparkException("The streaming aggregation query failed by state format invalidation. " + - "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " + - "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " + - "among restart. 
For the first case, you can try to restart the application without " + - "checkpoint or use the legacy Spark version to process the streaming state.", null) - abstract class StreamingAggregationStateManagerBaseImpl( protected val keyExpressions: Seq[Attribute], protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { @@ -99,9 +84,6 @@ abstract class StreamingAggregationStateManagerBaseImpl( @transient protected lazy val keyProjector = GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) - // Consider about the cost, only check the UnsafeRow format for the first row - private var checkFormat = true - override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) override def commit(store: StateStore): Long = store.commit() @@ -112,28 +94,6 @@ abstract class StreamingAggregationStateManagerBaseImpl( // discard and don't convert values to avoid computation store.getRange(None, None).map(_.key) } - - override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = { - if (checkFormat && SQLConf.get.getConf( - SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED) && row != null) { - if (schema.fields.length != row.numFields) { - throw new InvalidUnsafeRowException - } - schema.fields.zipWithIndex - .filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach { - case (_, index) => - val offsetAndSize = row.getLong(index) - val offset = (offsetAndSize >> 32).toInt - val size = offsetAndSize.toInt - if (size < 0 || - offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields || - offset + size > row.getSizeInBytes) { - throw new InvalidUnsafeRowException - } - } - checkFormat = false - } - } } /** @@ -154,9 +114,7 @@ class StreamingAggregationStateManagerImplV1( override def getStateValueSchema: StructType = inputRowAttributes.toStructType override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - val res = store.get(key) - unsafeRowFormatValidation(res, inputRowAttributes.toStructType) - res + store.get(key) } override def put(store: StateStore, row: UnsafeRow): Unit = { @@ -215,9 +173,7 @@ class StreamingAggregationStateManagerImplV2( return savedState } - val res = restoreOriginalRow(key, savedState) - unsafeRowFormatValidation(res, inputRowAttributes.toStructType) - res + restoreOriginalRow(key, savedState) } override def put(store: StateStore, row: UnsafeRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 2881e2e6f6c1..3145465e459b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -123,24 +123,4 @@ class StreamingAggregationStateManagerSuite extends StreamTest { // state manager should return row which is same as input row regardless of format version assert(inputRow === stateManager.get(memoryStateStore, keyRow)) } - - test("UnsafeRow format invalidation") { - // Pass the checking - val stateManager0 = StreamingAggregationStateManager.createStateManager(testKeyAttributes, - testOutputAttributes, 2) - stateManager0.unsafeRowFormatValidation(testRow, testOutputSchema) - // Fail for fields number not match - val stateManager1 = 
StreamingAggregationStateManager.createStateManager(testKeyAttributes, - testOutputAttributes, 2) - assertThrows[InvalidUnsafeRowException](stateManager1.unsafeRowFormatValidation( - testRow, StructType(testKeys.map(createIntegerField)))) - // Fail for invalid schema - val stateManager2 = StreamingAggregationStateManager.createStateManager(testKeyAttributes, - testOutputAttributes, 2) - val invalidSchema = StructType(testKeys.map(createIntegerField) ++ - Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true), - StructField("value2", IntegerType, false))) - assertThrows[InvalidUnsafeRowException](stateManager2.unsafeRowFormatValidation( - testRow, invalidSchema)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala deleted file mode 100644 index 9ecbbad25ef7..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationCompatibilitySuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import java.io.File - -import org.apache.commons.io.FileUtils - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions._ -import org.apache.spark.util.Utils - -class StreamingAggregationCompatibilitySuite extends StreamTest { - import testImplicits._ - - test("common functions") { - val inputData = MemoryStream[Int] - - val aggregated = - inputData.toDF().toDF("value") - .selectExpr( - "value", - "value % 5 AS id", - "CAST(value AS STRING) as str", - "CAST(value AS FLOAT) as f", - "CAST(value AS DOUBLE) as d", - "CAST(value AS DECIMAL) as dec", - "value % 3 AS mod", - "named_struct('key', CAST(value AS STRING), 'value', value) AS s") - .groupBy($"id") - .agg( - avg($"value").as("avg_v"), - avg($"f").as("avg_f"), - avg($"d").as("avg_d"), - avg($"dec").as("avg_dec"), - count($"value").as("cnt"), - first($"value").as("first_v"), - first($"s").as("first_s"), - last($"value").as("last_v"), - last($"s").as("last_s"), - min(struct("value", "str")).as("min_struct"), - max($"value").as("max_v"), - sum($"value").as("sum_v"), - sum($"f").as("sum_f"), - sum($"d").as("sum_d"), - sum($"dec").as("sum_dec"), - collect_list($"value").as("col_list"), - collect_set($"mod").as("col_set")) - .select("id", "avg_v", "avg_f", "avg_d", "avg_dec", "cnt", "first_v", "first_s.value", - "last_v", "last_s.value", "min_struct.value", "max_v", "sum_v", "sum_f", "sum_d", - "sum_dec", "col_list", "col_set") - - val resourceUri = this.getClass.getResource("/structured-streaming/" + - "checkpoint-version-2.4.5-for-compatibility-test-common-functions").toURI - val checkpointDir = Utils.createTempDir().getCanonicalFile - FileUtils.copyDirectory(new File(resourceUri), checkpointDir) - - inputData.addData(0 to 9: _*) - - testStream(aggregated, Complete)( - StartStream(checkpointLocation = checkpointDir.getAbsolutePath), -// AddData(inputData, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9), -// CheckAnswer( -// Row(0, 2.5, 2.5F, 2.5, 2.5000, 2, 0, 0, 5, 5, 0, 5, 5, 5.0, 5.0, 5, Seq(0, 5), -// Seq(0, 2)), -// Row(1, 3.5, 3.5F, 3.5, 3.5000, 2, 1, 1, 6, 6, 1, 6, 7, 7.0, 7.0, 7, Seq(1, 6), -// Seq(0, 1)), -// Row(2, 4.5, 4.5F, 4.5, 4.5000, 2, 2, 2, 7, 7, 2, 7, 9, 9.0, 9.0, 9, Seq(2, 7), -// Seq(1, 2)), -// Row(3, 5.5, 5.5F, 5.5, 5.5000, 2, 3, 3, 8, 8, 3, 8, 11, 11.0, 11.0, 11, Seq(3, 8), -// Seq(0, 2)), -// Row(4, 6.5, 6.5F, 6.5, 6.5000, 2, 4, 4, 9, 9, 4, 9, 13, 13.0, 13.0, 13, Seq(4, 9), -// Seq(0, 1))), - AddData(inputData, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - CheckAnswer( - Row(0, 7.5, 7.5, 7.5, 7.5000, 4, 0, 0, 15, 15, 0, 15, 30, 30.0, 30.0, 30, - Seq(0, 5, 10, 15), Seq(0, 1, 2)), - Row(1, 8.5, 8.5, 8.5, 8.5000, 4, 1, 1, 16, 16, 1, 16, 34, 34.0, 34.0, 34, - Seq(1, 6, 11, 16), Seq(0, 1, 2)), - Row(2, 9.5, 9.5, 9.5, 9.5000, 4, 2, 2, 17, 17, 2, 17, 38, 38.0, 38.0, 38, - Seq(2, 7, 12, 17), Seq(0, 1, 2)), - Row(3, 10.5, 10.5, 10.5, 10.5000, 4, 3, 3, 18, 18, 3, 18, 42, 42.0, 42.0, 42, - Seq(3, 8, 13, 18), Seq(0, 1, 2)), - Row(4, 11.5, 11.5, 11.5, 11.5000, 4, 4, 4, 19, 19, 4, 19, 46, 46.0, 46.0, 46, - Seq(4, 9, 14, 19), Seq(0, 1, 2))) - ) - } -} From 03130167feecb59e383fe7ee38862a263d5fb055 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 6 Jun 2020 04:50:35 +0800 Subject: [PATCH 06/11] nit --- .../streaming/state/StreamingAggregationStateManagerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 3145465e459b..daacdfd58c7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class StreamingAggregationStateManagerSuite extends StreamTest { // ============================ fields and method for test data ============================ From fc5ad199341da5974162496754af2de9863b5916 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 10 Jun 2020 20:33:42 +0800 Subject: [PATCH 07/11] temp for deduplicate --- .../org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala | 1 + .../streaming/state/HDFSBackedStateStoreProvider.scala | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 0a5e2e9a00fa..c7e150a70f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -74,6 +74,7 @@ object UnsafeRowUtils { if ((row.getLong(index) >> 32) != 0L) return false case _ => } + case _ => } if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) { return false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 8d6e4523eb37..03358a20fa63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -460,7 +460,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) if (!isValidated) { StateStoreProvider.validateStateRowFormat(keyRow, keySchema) - StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) isValidated = true } map.put(keyRow, valueRow) @@ -559,7 +559,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) if (!isValidated) { StateStoreProvider.validateStateRowFormat(keyRow, keySchema) - StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) isValidated = true } map.put(keyRow, valueRow) From 12eb2a256a8f33dcdb625e14bdd5e697d525187d Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: 
Mon, 15 Jun 2020 17:02:56 +0800 Subject: [PATCH 08/11] Add config of value checking for deduplicate --- .../apache/spark/sql/internal/SQLConf.scala | 33 ++++++++++++++----- .../state/HDFSBackedStateStoreProvider.scala | 8 ++--- .../streaming/state/StateStore.scala | 15 +++++++-- .../streaming/state/StateStoreConf.scala | 6 ++++ .../streaming/statefulOperators.scala | 5 +++ 5 files changed, 51 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e92b913e67c9..316d061cd771 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1237,6 +1237,25 @@ object SQLConf { .intConf .createWithDefault(10) + val STATE_STORE_FORMAT_VALIDATION_ENABLED = + buildConf("spark.sql.streaming.stateStore.formatValidation.enabled") + .internal() + .doc("When true, check if the UnsafeRow from the state store is valid or not when running " + + "streaming queries. This can happen if the state store format has been changed.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE = + buildConf("spark.sql.streaming.stateStore.formatValidation.checkValue") + .internal() + .doc("When true, check if the value UnsafeRow from the state store is valid or not when " + + "running streaming queries. For some operations, we won't check the value format since " + + "the state store save fake values, e.g. Deduplicate.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() @@ -1558,15 +1577,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STREAMING_STATE_FORMAT_CHECK_ENABLED = - buildConf("spark.sql.streaming.stateFormatCheck.enabled") - .internal() - .doc("When true, check if the UnsafeRow from the state store is valid or not when running " + - "streaming queries. 
This can happen if the state store format has been changed.") - .version("3.1.0") - .booleanConf - .createWithDefault(true) - val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") .internal() @@ -2755,6 +2765,11 @@ class SQLConf extends Serializable with Logging { def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED) + + def stateStoreFormatValidationCheckValue: Boolean = + getConf(STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 03358a20fa63..4329bbbbd12c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -459,8 +459,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) if (!isValidated) { - StateStoreProvider.validateStateRowFormat(keyRow, keySchema) - // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + StateStoreProvider.validateStateRowFormat( + keyRow, keySchema, valueRow, valueSchema, storeConf) isValidated = true } map.put(keyRow, valueRow) @@ -558,8 +558,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) if (!isValidated) { - StateStoreProvider.validateStateRowFormat(keyRow, keySchema) - // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema) + StateStoreProvider.validateStateRowFormat( + keyRow, keySchema, valueRow, valueSchema, storeConf) isValidated = true } map.put(keyRow, valueRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 39a924c7f1a7..171a54011681 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -246,9 +246,18 @@ object StateStoreProvider { /** * Use the expected schema to check whether the UnsafeRow is valid. 
*/ - def validateStateRowFormat(row: UnsafeRow, schema: StructType): Unit = { - if (SQLConf.get.getConf(SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED)) { - if (!UnsafeRowUtils.validateStructuralIntegrity(row, schema)) { + def validateStateRowFormat( + keyRow: UnsafeRow, + keySchema: StructType, + valueRow: UnsafeRow, + valueSchema: StructType, + conf: StateStoreConf): Unit = { + if (conf.formatValidationEnabled) { + if (!UnsafeRowUtils.validateStructuralIntegrity(keyRow, keySchema)) { + throw new InvalidUnsafeRowException + } + if (conf.formatValidationCheckValue && + !UnsafeRowUtils.validateStructuralIntegrity(valueRow, valueSchema)) { throw new InvalidUnsafeRowException } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index d145082a39b5..cb6ed7be3262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -43,6 +43,12 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) */ val providerClass: String = sqlConf.stateStoreProviderClass + /** Whether validate the underlying format or not. */ + val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled + + /** Whether validate the value format when the format invalidation enabled. */ + val formatValidationCheckValue: Boolean = sqlConf.stateStoreFormatValidationCheckValue + /** * Additional configurations related to state store. This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 74daaf80b10e..ca1b27d58c70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} @@ -454,6 +455,10 @@ case class StreamingDeduplicateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW + // is unrelated to the output schema. 
+ sqlContext.sessionState.conf.setConf(SQLConf.STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE, false) + child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, From 01007fb9f03c003bfc00d2e2358c9029b83f16e6 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 17 Jun 2020 18:57:48 +0800 Subject: [PATCH 09/11] change the config for specific operator only --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 ------------- .../streaming/state/StateStoreConf.scala | 15 +++++++++++---- .../execution/streaming/state/StateStoreRDD.scala | 5 +++-- .../sql/execution/streaming/state/package.scala | 3 ++- .../execution/streaming/statefulOperators.scala | 9 ++++----- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 316d061cd771..f8632cddb279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1246,16 +1246,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE = - buildConf("spark.sql.streaming.stateStore.formatValidation.checkValue") - .internal() - .doc("When true, check if the value UnsafeRow from the state store is valid or not when " + - "running streaming queries. For some operations, we won't check the value format since " + - "the state store save fake values, e.g. Deduplicate.") - .version("3.1.0") - .booleanConf - .createWithDefault(true) - val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() @@ -2767,9 +2757,6 @@ class SQLConf extends Serializable with Logging { def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED) - def stateStoreFormatValidationCheckValue: Boolean = - getConf(STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE) - def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index cb6ed7be3262..7736c2889b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -class StateStoreConf(@transient private val sqlConf: SQLConf) +class StateStoreConf( + @transient private val sqlConf: SQLConf, + extraOptions: Map[String, String] = Map.empty) extends Serializable { def this() = this(new SQLConf) @@ -47,16 +49,21 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled /** Whether validate the value format when the format invalidation enabled. 
*/ - val formatValidationCheckValue: Boolean = sqlConf.stateStoreFormatValidationCheckValue + val formatValidationCheckValue: Boolean = + extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "false") == "true" /** * Additional configurations related to state store. This will capture all configs in - * SQLConf that start with `spark.sql.streaming.stateStore.` */ + * SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific + * operator. + */ val confs: Map[String, String] = - sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) + sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) ++ extraOptions } object StateStoreConf { + val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue" + val empty = new StateStoreConf() def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 4a69a48fed75..0eb3dce1bbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -45,10 +45,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( valueSchema: StructType, indexOrdinal: Option[Int], sessionState: SessionState, - @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef], + extraOptions: Map[String, String] = Map.empty) extends RDD[U](dataRDD) { - private val storeConf = new StateStoreConf(sessionState.conf) + private val storeConf = new StateStoreConf(sessionState.conf, extraOptions) // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val hadoopConfBroadcast = dataRDD.context.broadcast( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b6021438e902..891a878a1d7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -55,7 +55,8 @@ package object state { valueSchema: StructType, indexOrdinal: Option[Int], sessionState: SessionState, - storeCoordinator: Option[StateStoreCoordinatorRef])( + storeCoordinator: Option[StateStoreCoordinatorRef], + extraOptions: Map[String, String] = Map.empty)( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index ca1b27d58c70..dc0d67d0cc20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -455,17 +455,16 @@ case class StreamingDeduplicateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW - // is unrelated to the output 
-    sqlContext.sessionState.conf.setConf(SQLConf.STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE, false)
-
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+      Some(sqlContext.streams.stateStoreCoordinator),
+      // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
+      // is unrelated to the output schema.
+      Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "true")) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")

From fd74ff9c337d06f4cb4ccfc638d837b5ea3d0e11 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Wed, 17 Jun 2020 21:56:03 +0800
Subject: [PATCH 10/11] Address comments

---
 .../org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala    | 2 ++
 .../streaming/state/HDFSBackedStateStoreProvider.scala         | 2 ++
 .../spark/sql/execution/streaming/state/StateStore.scala       | 3 +--
 .../spark/sql/execution/streaming/state/StateStoreConf.scala   | 2 +-
 .../apache/spark/sql/execution/streaming/state/package.scala   | 3 ++-
 .../spark/sql/execution/streaming/statefulOperators.scala      | 2 +-
 6 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index c7e150a70f5c..50d8b419880e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -74,6 +74,8 @@ object UnsafeRowUtils {
             if ((row.getLong(index) >> 32) != 0L) return false
           case _ =>
         }
+      case (field, index) if field.dataType == NullType =>
+        if (!row.isNullAt(index) || row.getLong(index) != 0L) return false
       case _ =>
     }
     if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 4329bbbbd12c..31618922e44c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -259,6 +259,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var numberOfVersionsToRetainInMemory: Int = _
+  // TODO: The validation should be moved to a higher level so that it works for all state store
+  // implementations
   @volatile private var isValidated = false
 
   private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 171a54011681..092ca968f59c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -32,7 +32,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -149,7 +148,7 @@ case class StateStoreCustomTimingMetric(name: String, desc: String) extends Stat
  * An exception thrown when an invalid UnsafeRow is detected in state store.
  */
 class InvalidUnsafeRowException
-  extends SparkException("The streaming query failed by state format invalidation. " +
+  extends RuntimeException("The streaming query failed by state format invalidation. " +
     "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " +
     "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " +
     "among restart. For the first case, you can try to restart the application without " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 7736c2889b4d..84d0b76ac915 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -50,7 +50,7 @@ class StateStoreConf(
 
   /** Whether validate the value format when the format invalidation enabled. */
   val formatValidationCheckValue: Boolean =
-    extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "false") == "true"
+    extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"
 
   /**
    * Additional configurations related to state store. This will capture all configs in
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 891a878a1d7b..c7a332b6d778 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -79,7 +79,8 @@ package object state {
       valueSchema,
       indexOrdinal,
       sessionState,
-      storeCoordinator)
+      storeCoordinator,
+      extraOptions)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index dc0d67d0cc20..a9c01e69b9b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -464,7 +464,7 @@ case class StreamingDeduplicateExec(
       Some(sqlContext.streams.stateStoreCoordinator),
       // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
       // is unrelated to the output schema.
-      Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "true")) { (store, iter) =>
+      Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")

From 557eb3099b3d0abe1fd2d7d91754fa747e05d200 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Thu, 18 Jun 2020 09:31:58 +0800
Subject: [PATCH 11/11] address comments

---
 .../org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala  | 4 ++--
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala   | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index 50d8b419880e..37a34fac6636 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -74,8 +74,8 @@ object UnsafeRowUtils {
             if ((row.getLong(index) >> 32) != 0L) return false
           case _ =>
         }
-      case (field, index) if field.dataType == NullType =>
-        if (!row.isNullAt(index) || row.getLong(index) != 0L) return false
+      case (_, index) if row.isNullAt(index) =>
+        if (row.getLong(index) != 0L) return false
       case _ =>
     }
     if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f8632cddb279..6bbeb2de7538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1241,7 +1241,8 @@ object SQLConf {
     buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
       .internal()
       .doc("When true, check if the UnsafeRow from the state store is valid or not when running " +
-        "streaming queries. This can happen if the state store format has been changed.")
+        "streaming queries. This can happen if the state store format has been changed. Note, " +
+        "the feature is only effective in the built-in HDFS state store provider now.")
       .version("3.1.0")
       .booleanConf
       .createWithDefault(true)
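
Editor's note: the sketch below is not part of the patch series; it is a minimal, self-contained Scala illustration of how the per-operator option introduced in PATCH 09 and finalized in PATCH 10 is intended to resolve. The object and method names are hypothetical; only the option key and the getOrElse lookup mirror StateStoreConf in the patches.

object FormatValidationCheckValueSketch {
  // Stand-in for StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG from PATCH 09.
  val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue"

  // Mirrors the resolution in StateStoreConf after PATCH 10: value-row validation is on by
  // default and is disabled only when an operator passes "false" through its extraOptions.
  def checkValueEnabled(extraOptions: Map[String, String]): Boolean =
    extraOptions.getOrElse(FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"

  def main(args: Array[String]): Unit = {
    // Default: an operator that passes no extra options keeps value-row validation enabled.
    assert(checkValueEnabled(Map.empty[String, String]))
    // Deduplication-style operator: opts out because the stored value is only a placeholder row,
    // which is what StreamingDeduplicateExec does in the final form of the series.
    assert(!checkValueEnabled(Map(FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")))
    println("sketch assertions passed")
  }
}

The design point reflected here is that the store-wide formatValidation.enabled flag still governs the overall check, while the per-operator extraOptions entry only relaxes the check on value rows for operators whose stored values are unrelated to the output schema.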