Skip to content

Commit 4d14961

Browse files
committed
initial commit
1 parent e5c3463 commit 4d14961

File tree

3 files changed

+74
-3
lines changed

3 files changed

+74
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,15 @@ object SQLConf {
15451545
.booleanConf
15461546
.createWithDefault(true)
15471547

1548+
val STREAMING_STATE_FORMAT_CHECK_ENABLED =
1549+
buildConf("spark.sql.streaming.stateFormatCheck.enabled")
1550+
.doc("Whether to detect a streaming query may try to use an invalid UnsafeRow in the " +
1551+
"state store.")
1552+
.version("3.1.0")
1553+
.internal()
1554+
.booleanConf
1555+
.createWithDefault(true)
1556+
15481557
val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
15491558
buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled")
15501559
.internal()

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
package org.apache.spark.sql.execution.streaming.state
1919

20+
import org.apache.spark.SparkException
2021
import org.apache.spark.internal.Logging
2122
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
2223
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
24+
import org.apache.spark.sql.internal.SQLConf
2325
import org.apache.spark.sql.types.StructType
2426

2527
/**
@@ -59,6 +61,9 @@ sealed trait StreamingAggregationStateManager extends Serializable {
5961

6062
/** Return an iterator containing all the values in target state store. */
6163
def values(store: StateStore): Iterator[UnsafeRow]
64+
65+
/** Check the UnsafeRow format with the expected schema */
66+
def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit
6267
}
6368

6469
object StreamingAggregationStateManager extends Logging {
@@ -77,13 +82,24 @@ object StreamingAggregationStateManager extends Logging {
7782
}
7883
}
7984

85+
/**
86+
* An exception thrown when an invalid UnsafeRow is detected.
87+
*/
88+
class InvalidUnsafeRowException
89+
extends SparkException("The UnsafeRow format is invalid. This may happen when using the old " +
90+
"version or broken checkpoint file. To resolve this problem, you can try to restart the " +
91+
"application or use the legacy way to process streaming state.", null)
92+
8093
abstract class StreamingAggregationStateManagerBaseImpl(
8194
protected val keyExpressions: Seq[Attribute],
8295
protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {
8396

8497
@transient protected lazy val keyProjector =
8598
GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)
8699

100+
// Consider about the cost, only check the UnsafeRow format for the first row
101+
private var checkFormat = true
102+
87103
override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row)
88104

89105
override def commit(store: StateStore): Long = store.commit()
@@ -94,6 +110,28 @@ abstract class StreamingAggregationStateManagerBaseImpl(
94110
// discard and don't convert values to avoid computation
95111
store.getRange(None, None).map(_.key)
96112
}
113+
114+
override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
115+
if (checkFormat && SQLConf.get.getConf(
116+
SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED) && row != null) {
117+
if (schema.fields.length != row.numFields) {
118+
throw new InvalidUnsafeRowException
119+
}
120+
schema.fields.zipWithIndex
121+
.filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach {
122+
case (_, index) =>
123+
val offsetAndSize = row.getLong(index)
124+
val offset = (offsetAndSize >> 32).toInt
125+
val size = offsetAndSize.toInt
126+
if (size < 0 ||
127+
offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields ||
128+
offset + size > row.getSizeInBytes) {
129+
throw new InvalidUnsafeRowException
130+
}
131+
}
132+
checkFormat = false
133+
}
134+
}
97135
}
98136

99137
/**
@@ -114,7 +152,9 @@ class StreamingAggregationStateManagerImplV1(
114152
override def getStateValueSchema: StructType = inputRowAttributes.toStructType
115153

116154
override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
117-
store.get(key)
155+
val res = store.get(key)
156+
unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
157+
res
118158
}
119159

120160
override def put(store: StateStore, row: UnsafeRow): Unit = {
@@ -173,7 +213,9 @@ class StreamingAggregationStateManagerImplV2(
173213
return savedState
174214
}
175215

176-
restoreOriginalRow(key, savedState)
216+
val res = restoreOriginalRow(key, savedState)
217+
unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
218+
res
177219
}
178220

179221
override def put(store: StateStore, row: UnsafeRow): Unit = {

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
2020
import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow}
2121
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
2222
import org.apache.spark.sql.streaming.StreamTest
23-
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
23+
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
2424

2525
class StreamingAggregationStateManagerSuite extends StreamTest {
2626
// ============================ fields and method for test data ============================
@@ -123,4 +123,24 @@ class StreamingAggregationStateManagerSuite extends StreamTest {
123123
// state manager should return row which is same as input row regardless of format version
124124
assert(inputRow === stateManager.get(memoryStateStore, keyRow))
125125
}
126+
127+
test("UnsafeRow format invalidation") {
128+
// Pass the checking
129+
val stateManager0 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
130+
testOutputAttributes, 2)
131+
stateManager0.unsafeRowFormatValidation(testRow, testOutputSchema)
132+
// Fail for fields number not match
133+
val stateManager1 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
134+
testOutputAttributes, 2)
135+
assertThrows[InvalidUnsafeRowException](stateManager1.unsafeRowFormatValidation(
136+
testRow, StructType(testKeys.map(createIntegerField))))
137+
// Fail for invalid schema
138+
val stateManager2 = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
139+
testOutputAttributes, 2)
140+
val invalidSchema = StructType(testKeys.map(createIntegerField) ++
141+
Seq(StructField("struct", StructType(Seq(StructField("value1", StringType, true))), true),
142+
StructField("value2", IntegerType, false)))
143+
assertThrows[InvalidUnsafeRowException](stateManager2.unsafeRowFormatValidation(
144+
testRow, invalidSchema))
145+
}
126146
}

0 commit comments

Comments
 (0)