 
 package org.apache.spark.sql.execution.streaming.state
 
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -59,6 +61,9 @@ sealed trait StreamingAggregationStateManager extends Serializable {
 
   /** Return an iterator containing all the values in target state store. */
   def values(store: StateStore): Iterator[UnsafeRow]
+
+  /** Check the UnsafeRow format against the expected schema. */
+  def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit
 }
 
 object StreamingAggregationStateManager extends Logging {
@@ -77,13 +82,24 @@ object StreamingAggregationStateManager extends Logging {
   }
 }
 
+/**
+ * An exception thrown when an invalid UnsafeRow is detected.
+ */
+class InvalidUnsafeRowException
+  extends SparkException("The UnsafeRow format is invalid. This may happen when using an old " +
+    "Spark version or a broken checkpoint file. To resolve this problem, you can try to " +
+    "restart the application or use the legacy way to process streaming state.", null)
+
 abstract class StreamingAggregationStateManagerBaseImpl(
     protected val keyExpressions: Seq[Attribute],
     protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {
 
   @transient protected lazy val keyProjector =
     GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)
 
+  // Considering the cost, only check the UnsafeRow format for the first row.
+  private var checkFormat = true
+
   override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row)
 
   override def commit(store: StateStore): Long = store.commit()
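A note on the exception added above: because InvalidUnsafeRowException subclasses SparkException, any caller already matching on SparkException also catches the new failure without changes. A minimal sketch, assuming the classes from this patch are on the classpath:

    import org.apache.spark.SparkException

    try {
      throw new InvalidUnsafeRowException
    } catch {
      // An existing SparkException handler surfaces the "UnsafeRow format is
      // invalid" message added by this patch.
      case e: SparkException => println(e.getMessage)
    }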
@@ -94,6 +110,28 @@ abstract class StreamingAggregationStateManagerBaseImpl(
     // discard and don't convert values to avoid computation
     store.getRange(None, None).map(_.key)
   }
+
+  override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
+    if (checkFormat && SQLConf.get.getConf(
+        SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED) && row != null) {
+      if (schema.fields.length != row.numFields) {
+        throw new InvalidUnsafeRowException
+      }
+      schema.fields.zipWithIndex
+        .filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach {
+          case (_, index) =>
+            val offsetAndSize = row.getLong(index)
+            val offset = (offsetAndSize >> 32).toInt
+            val size = offsetAndSize.toInt
+            if (size < 0 ||
+                offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields ||
+                offset + size > row.getSizeInBytes) {
+              throw new InvalidUnsafeRowException
+            }
+        }
+      checkFormat = false
+    }
+  }
 }
 
 /**
@@ -114,7 +152,9 @@ class StreamingAggregationStateManagerImplV1(
   override def getStateValueSchema: StructType = inputRowAttributes.toStructType
 
   override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
-    store.get(key)
+    val res = store.get(key)
+    unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
+    res
   }
 
   override def put(store: StateStore, row: UnsafeRow): Unit = {
@@ -173,7 +213,9 @@ class StreamingAggregationStateManagerImplV2(
       return savedState
     }
 
-    restoreOriginalRow(key, savedState)
+    val res = restoreOriginalRow(key, savedState)
+    unsafeRowFormatValidation(res, inputRowAttributes.toStructType)
+    res
   }
 
   override def put(store: StateStore, row: UnsafeRow): Unit = {
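The V2 manager stores only the non-key columns as the state value, so the format check runs on the row rebuilt by restoreOriginalRow, whose schema is the full inputRowAttributes. A sketch of that reconstruction with the already-imported GenerateUnsafeRowJoiner (schemas here are illustrative):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
    import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

    object RestoreRowDemo {
      def main(args: Array[String]): Unit = {
        val keySchema = StructType(Seq(StructField("k", IntegerType)))
        val valueSchema = StructType(Seq(StructField("v", LongType)))
        val key = UnsafeProjection.create(keySchema)(InternalRow(1))
        val value = UnsafeProjection.create(valueSchema)(InternalRow(42L))
        // Joining key and value yields the full row that is then validated
        // against inputRowAttributes.toStructType.
        val joined = GenerateUnsafeRowJoiner.create(keySchema, valueSchema).join(key, value)
        assert(joined.numFields == 2)
      }
    }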