Add config of value checking for deduplicate
xuanyuanking committed Jun 17, 2020
commit 12eb2a256a8f33dcdb625e14bdd5e697d525187d
@@ -1237,6 +1237,25 @@ object SQLConf {
       .intConf
       .createWithDefault(10)

+  val STATE_STORE_FORMAT_VALIDATION_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
+      .internal()
+      .doc("When true, check if the UnsafeRow from the state store is valid or not when running " +

Member commented:
Change "UnsafeRow" to "checkpoint"? Most end users do not know what an UnsafeRow is.

Member (author) replied:
Sure, will submit a follow-up PR today.

"streaming queries. This can happen if the state store format has been changed.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

val STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE =
buildConf("spark.sql.streaming.stateStore.formatValidation.checkValue")
.internal()
.doc("When true, check if the value UnsafeRow from the state store is valid or not when " +
"running streaming queries. For some operations, we won't check the value format since " +
"the state store save fake values, e.g. Deduplicate.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

   val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
     buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
       .internal()
@@ -1558,15 +1577,6 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)

-  val STREAMING_STATE_FORMAT_CHECK_ENABLED =
-    buildConf("spark.sql.streaming.stateFormatCheck.enabled")
-      .internal()
-      .doc("When true, check if the UnsafeRow from the state store is valid or not when running " +
-        "streaming queries. This can happen if the state store format has been changed.")
-      .version("3.1.0")
-      .booleanConf
-      .createWithDefault(true)
-
   val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
     buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled")
       .internal()
@@ -2755,6 +2765,11 @@ class SQLConf extends Serializable with Logging {

   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

+  def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
+
+  def stateStoreFormatValidationCheckValue: Boolean =
+    getConf(STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE)
+
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)

   def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
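
For illustration, here is a minimal sketch of how these flags could be toggled on a session before a streaming query starts (the local SparkSession setup is an assumption for the demo; the config keys are the ones defined in the diff above):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("StateFormatValidationSketch")
      .getOrCreate()

    // Both flags default to true, per createWithDefault(true) above.
    spark.conf.set("spark.sql.streaming.stateStore.formatValidation.enabled", "true")
    // Skip only the value-row check, e.g. for operators that store placeholder values.
    spark.conf.set("spark.sql.streaming.stateStore.formatValidation.checkValue", "false")
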
@@ -459,8 +459,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
           // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
           valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
           if (!isValidated) {
-            StateStoreProvider.validateStateRowFormat(keyRow, keySchema)
-            // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema)
+            StateStoreProvider.validateStateRowFormat(
+              keyRow, keySchema, valueRow, valueSchema, storeConf)
             isValidated = true
           }
           map.put(keyRow, valueRow)
@@ -558,8 +558,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
           // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
           valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
           if (!isValidated) {
-            StateStoreProvider.validateStateRowFormat(keyRow, keySchema)
-            // StateStoreProvider.validateStateRowFormat(valueRow, valueSchema)
+            StateStoreProvider.validateStateRowFormat(
+              keyRow, keySchema, valueRow, valueSchema, storeConf)
             isValidated = true
           }
           map.put(keyRow, valueRow)
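
As a worked example of the rounding in the pointTo calls above: Scala integer division floors, so (valueSize / 8) * 8 strips a trailing 4-byte over-count, which the surrounding comment attributes to rows written via RowBasedKeyValueBatch and persisted into checkpoint data:

    val valueSize = 36                // e.g. 32 real bytes + 4 spurious trailing bytes
    val aligned = (valueSize / 8) * 8 // 36 / 8 = 4 (floored), then 4 * 8 = 32
    assert(aligned == 32)
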
@@ -246,9 +246,18 @@ object StateStoreProvider {
   /**
    * Use the expected schema to check whether the UnsafeRow is valid.
    */
-  def validateStateRowFormat(row: UnsafeRow, schema: StructType): Unit = {
-    if (SQLConf.get.getConf(SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED)) {
-      if (!UnsafeRowUtils.validateStructuralIntegrity(row, schema)) {
+  def validateStateRowFormat(
+      keyRow: UnsafeRow,
+      keySchema: StructType,
+      valueRow: UnsafeRow,
+      valueSchema: StructType,
+      conf: StateStoreConf): Unit = {
+    if (conf.formatValidationEnabled) {
+      if (!UnsafeRowUtils.validateStructuralIntegrity(keyRow, keySchema)) {
         throw new InvalidUnsafeRowException
       }
+      if (conf.formatValidationCheckValue &&
+        !UnsafeRowUtils.validateStructuralIntegrity(valueRow, valueSchema)) {
+        throw new InvalidUnsafeRowException
+      }
     }
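
For intuition about what the widened check does, a standalone sketch of the underlying UnsafeRowUtils call (the schema and values are made up for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val schema = new StructType().add("id", LongType).add("name", StringType)
    val row = UnsafeProjection.create(schema)(
      InternalRow(1L, UTF8String.fromString("a")))

    // Passes against the schema the row was written with; a row read back
    // against a changed state store schema would generally fail this check,
    // which is what makes validateStateRowFormat throw.
    assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
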
@@ -43,6 +43,12 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
    */
   val providerClass: String = sqlConf.stateStoreProviderClass

+  /** Whether to validate the underlying state store format or not. */
+  val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled
+
+  /** Whether to validate the value format when format validation is enabled. */
+  val formatValidationCheckValue: Boolean = sqlConf.stateStoreFormatValidationCheckValue
+
   /**
    * Additional configurations related to state store. This will capture all configs in
    * SQLConf that start with `spark.sql.streaming.stateStore.` */
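
The prefix capture described in that comment can be pictured as a simple filter over the session's SQL configs (a rough sketch, not necessarily the exact implementation):

    val extraStateStoreConfs: Map[String, String] =
      sqlConf.getAllConfs.filter { case (key, _) =>
        key.startsWith("spark.sql.streaming.stateStore.")
      }
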
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
@@ -454,6 +455,10 @@ case class StreamingDeduplicateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver

+    // We won't check the value row in the state store, since the stored value
+    // StreamingDeduplicateExec.EMPTY_ROW is unrelated to the output schema.
+    sqlContext.sessionState.conf.setConf(SQLConf.STATE_STORE_FORMAT_VALIDATION_CHECK_VALUE, false)
+
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
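
For reference, the placeholder that makes value validation meaningless here is a single-NullType row; StreamingDeduplicateExec's companion object builds it along these lines (paraphrased from the Spark codebase):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.{DataType, NullType}

    // Every deduplication key maps to this same one-null-column value,
    // so its layout has nothing to do with the query's output schema.
    val EMPTY_ROW =
      UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))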