Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
7dec5eb
[SPARK-47705][INFRA] Sort LogKey alphabetically and build a test to e…
dtenedor Apr 3, 2024
6a0555c
[SPARK-47700][SQL] Fix formatting of error messages with treeNode
jchen5 Apr 3, 2024
49eefc5
[SPARK-47722][SS] Wait until RocksDB background work finish before cl…
WweiL Apr 3, 2024
fbe6b1d
[SPARK-47721][DOC] Guidelines for the Structured Logging Framework
gengliangwang Apr 3, 2024
e3aab8c
[SPARK-47210][SQL] Addition of implicit casting without indeterminate…
mihailomilosevic2001 Apr 3, 2024
d87ac8e
[SPARK-47708][CONNECT] Do not log gRPC exception to stderr in PySpark
nemanja-boric-databricks Apr 4, 2024
447f8af
[SPARK-47720][CORE] Update `spark.speculation.multiplier` to 3 and `s…
dongjoon-hyun Apr 4, 2024
678aeb7
[SPARK-47683][PYTHON][BUILD] Decouple PySpark core API to pyspark.cor…
HyukjinKwon Apr 4, 2024
c25fd93
[SPARK-47705][INFRA][FOLLOWUP] Sort LogKey alphabetically and build a…
panbingkun Apr 4, 2024
d272a1b
[SPARK-47724][PYTHON][TESTS] Add an environment variable for testing …
HyukjinKwon Apr 4, 2024
d75c775
[SPARK-46812][PYTHON][TESTS][FOLLOWUP] Skip `pandas`-required tests i…
dongjoon-hyun Apr 4, 2024
3f6ac60
[SPARK-47577][CORE][PART1] Migrate logError with variables to structu…
gengliangwang Apr 4, 2024
f6999df
[SPARK-47081][CONNECT] Support Query Execution Progress
grundprinzip Apr 4, 2024
bffb02d
[SPARK-47565][PYTHON] PySpark worker pool crash resilience
Apr 4, 2024
3b8aea3
Revert "[SPARK-47708][CONNECT] Do not log gRPC exception to stderr in…
nemanja-boric-databricks Apr 4, 2024
5f9f5db
[SPARK-47689][SQL][FOLLOWUP] More accurate file path in TASK_WRITE_FA…
cloud-fan Apr 4, 2024
5ca3467
[SPARK-47729][PYTHON][TESTS] Get the proper default port for pyspark-…
HyukjinKwon Apr 4, 2024
25fc67f
[SPARK-47728][DOC] Document G1 Concurrent GC metrics
LucaCanali Apr 4, 2024
e3405c1
[SPARK-47610][CONNECT][FOLLOWUP] Add -Dio.netty.tryReflectionSetAcces…
pan3793 Apr 4, 2024
3fd0cd6
[SPARK-47598][CORE] MLLib: Migrate logError with variables to structu…
panbingkun Apr 4, 2024
240923c
[SPARK-46812][PYTHON][TESTS][FOLLOWUP] Check should_test_connect and …
dongjoon-hyun Apr 4, 2024
fb96b1a
[SPARK-47723][CORE][TESTS] Introduce a tool that can sort alphabetica…
panbingkun Apr 5, 2024
404d58c
[SPARK-47081][CONNECT][FOLLOW-UP] Add the `shell` module into PyPI pa…
HyukjinKwon Apr 5, 2024
b9ca91d
[SPARK-47712][CONNECT] Allow connect plugins to create and process Da…
tomvanbussel Apr 5, 2024
0107435
[SPARK-47734][PYTHON][TESTS] Fix flaky DataFrame.writeStream doctest …
JoshRosen Apr 5, 2024
d5620cb
[SPARK-47289][SQL] Allow extensions to log extended information in ex…
parthchandra Apr 5, 2024
aeb082e
[SPARK-47081][CONNECT][TESTS][FOLLOW-UP] Skip the flaky doctests for now
HyukjinKwon Apr 5, 2024
97e63ff
[SPARK-47735][PYTHON][TESTS] Make pyspark.testing.connectutils compat…
HyukjinKwon Apr 5, 2024
12d0367
[SPARK-47724][PYTHON][TESTS][FOLLOW-UP] Make testing script to inheri…
HyukjinKwon Apr 5, 2024
6bd0ccf
[SPARK-47511][SQL][FOLLOWUP] Rename the config REPLACE_NULLIF_USING_W…
cloud-fan Apr 5, 2024
c34baeb
[SPARK-47719][SQL] Change spark.sql.legacy.timeParserPolicy default t…
srielau Apr 5, 2024
18072b5
[SPARK-47577][CORE][PART2] Migrate logError with variables to structu…
gengliangwang Apr 5, 2024
1efbf43
[SPARK-47310][SS] Add micro-benchmark for merge operations for multip…
anishshri-db Apr 5, 2024
d1ace24
[SPARK-47582][SQL] Migrate Catalyst logInfo with variables to structu…
dtenedor Apr 5, 2024
11abc64
[SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when…
szehon-ho Apr 6, 2024
42dc815
[SPARK-47743][CORE] Use milliseconds as the time unit in logging
gengliangwang Apr 6, 2024
7385f19
[SPARK-47592][CORE] Connector module: Migrate logError with variables…
panbingkun Apr 6, 2024
d69df59
[SPARK-47738][BUILD] Upgrade Kafka to 3.7.0
panbingkun Apr 6, 2024
60a3fbc
[SPARK-47727][PYTHON] Make SparkConf to root level to for both SparkS…
HyukjinKwon Apr 6, 2024
644687b
[SPARK-47709][BUILD] Upgrade tink to 1.13.0
LuciferYang Apr 6, 2024
4d9dbb3
[SPARK-46722][CONNECT][SS][TESTS][FOLLOW-UP] Drop the tables after te…
HyukjinKwon Apr 7, 2024
c11585a
[SPARK-47751][PYTHON][CONNECT] Make pyspark.worker_utils compatible w…
HyukjinKwon Apr 7, 2024
d743012
[SPARK-47753][PYTHON][CONNECT][TESTS] Make pyspark.testing compatible…
HyukjinKwon Apr 7, 2024
f7dff4a
[SPARK-47752][PS][CONNECT] Make pyspark.pandas compatible with pyspar…
HyukjinKwon Apr 7, 2024
e92e8f5
[SPARK-47744] Add support for negative-valued bytes in range encoder
neilramaswamy Apr 7, 2024
0c992b2
[SPARK-47755][CONNECT] Pivot should fail when the number of distinct …
zhengruifeng Apr 7, 2024
b299b2b
[SPARK-47299][PYTHON][DOCS] Use the same `versions.json` in the dropd…
panbingkun Apr 8, 2024
cc6c0eb
[MINOR][TESTS] Deduplicate test cases `test_parse_datatype_string`
HyukjinKwon Apr 8, 2024
ad2367c
[MINOR][PYTHON][SS][TESTS] Drop the tables after being used at `test_…
HyukjinKwon Apr 8, 2024
f576b85
[SPARK-47541][SQL] Collated strings in complex types supporting opera…
nikolamand-db Apr 8, 2024
d55bb61
[SPARK-47558][SS] State TTL support for ValueState
sahnib Apr 8, 2024
3a39ac2
[SPARK-47713][SQL][CONNECT] Fix a self-join failure
zhengruifeng Apr 8, 2024
eb8e997
[SPARK-47657][SQL] Implement collation filter push down support per f…
stefankandic Apr 8, 2024
f0d8f82
[SPARK-47750][DOCS][SQL] Postgres: Document Mapping Spark SQL Data Ty…
yaooqinn Apr 8, 2024
211afd4
[MINOR][PYTHON][CONNECT][TESTS] Enable `MapInPandasParityTests.test_d…
zhengruifeng Apr 8, 2024
f94d95d
[SPARK-47762][PYTHON][CONNECT] Add pyspark.sql.connect.protobuf into …
HyukjinKwon Apr 8, 2024
29d077f
[SPARK-47748][BUILD] Upgrade `zstd-jni` to 1.5.6-2
panbingkun Apr 8, 2024
60806c6
[SPARK-47746] Implement ordinal-based range encoding in the RocksDBSt…
neilramaswamy Apr 8, 2024
134a139
[SPARK-47681][SQL] Add schema_of_variant expression
chenhao-db Apr 8, 2024
abb7b04
[SPARK-47504][SQL] Resolve AbstractDataType simpleStrings for StringT…
mihailomilosevic2001 Apr 8, 2024
91b2331
[WIP] ListStateTTL implementation
ericm-db Apr 8, 2024
479392a
adding log lines
ericm-db Apr 8, 2024
7aab43e
test cases pass
ericm-db Apr 8, 2024
71f960d
spacing
ericm-db Apr 8, 2024
998764c
using NextIterator instead
ericm-db Apr 8, 2024
1dcb7d8
refactor feedback
ericm-db Apr 9, 2024
47867e7
undoing unnecessary change
ericm-db Apr 9, 2024
cfd30c3
refactor get_ttl_value
ericm-db Apr 9, 2024
4a19cb7
refactor test case
ericm-db Apr 9, 2024
993125c
specific doc for clearIfExpired
ericm-db Apr 9, 2024
fd5200f
moving isExpired to common place
ericm-db Apr 9, 2024
d43ffb1
refactoring to use common utils
ericm-db Apr 9, 2024
30f6094
updating interface header
ericm-db Apr 9, 2024
e9376d9
Map State TTL, Initial Commit
ericm-db Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[WIP] ListStateTTL implementation
  • Loading branch information
ericm-db committed Apr 8, 2024
commit 91b2331c1fe34dfde14107a27739e0a27eac280c
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
*/
def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]

/**
* Function to create new or return existing list state variable of given type
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
* from the state store. Any state update resets the ttl to current processing time plus
* ttlDuration.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
*
* @param stateName - name of the state variable
* @param valEncoder - SQL encoder for state variable
* @param ttlConfig - the ttl configuration (time to live duration etc.)
* @tparam T - type of state variable
* @return - instance of ValueState of type T that can be used to store state persistently
*/
def getListState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T]

/**
* Creates new or returns existing map state associated with stateName.
* The MapState persists Key-Value pairs of type [K, V].
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}

/**
* Provides concrete implementation for list of values associated with a state variable
* used in the streaming transformWithState operator.
*
* @param store - reference to the StateStore instance to be used for storing state
* @param stateName - name of logical state partition
* @param keyEnc - Spark SQL encoder for key
* @param valEncoder - Spark SQL encoder for value
* @tparam S - data type of object that will be stored in the list
*/
class ListStateImplWithTTL[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
ttlConfig: TTLConfig,
batchTimestampMs: Long)
extends SingleKeyTTLStateImpl(stateName, store, batchTimestampMs) with ListState[S] with Logging {

private val keySerializer = keyExprEnc.createSerializer()

private val stateTypesEncoder = StateTypesEncoder(
keySerializer, valEncoder, stateName, hasTtl = true)

private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

initialize()

private def initialize(): Unit = {
store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
}
/** Whether state exists or not. */
override def exists(): Boolean = {
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val stateValue = store.get(encodedGroupingKey, stateName)
stateValue != null
}

/**
* Get the state value if it exists. If the state does not exist in state store, an
* empty iterator is returned.
*/
override def get(): Iterator[S] = {
val encodedKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
logError(s"### size: ${store.valuesIterator(encodedKey, stateName).size}")
var currentRow: UnsafeRow = null

new Iterator[S] {
override def hasNext: Boolean = {
if (currentRow == null) {
setNextValidRow()
}

currentRow != null
}

override def next(): S = {
if (currentRow == null) {
setNextValidRow()
}
if (currentRow == null) {
throw new NoSuchElementException("Iterator is at the end")
}
val result = stateTypesEncoder.decodeValue(currentRow)
currentRow = null
result
}

// sets currentRow to a valid state, where we are
// pointing to a non-expired row
private def setNextValidRow(): Unit = {
logError(s"### setNextValidRow")
assert(currentRow == null)
if (unsafeRowValuesIterator.hasNext) {
logError(s"### set currentRow")
currentRow = unsafeRowValuesIterator.next()
} else {
currentRow = null
return
}
while (unsafeRowValuesIterator.hasNext && isExpired(currentRow)) {
logError(s"### isExpired, ${isExpired(currentRow)}")
currentRow = unsafeRowValuesIterator.next()
}

// in this case, we have iterated to the end, and there are no
// non-expired values
if (currentRow != null && isExpired(currentRow)) {
currentRow = null
}
}
}
}

/** Update the value of the list. */
override def put(newState: Array[S]): Unit = {
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var isFirst = true

newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
if (isFirst) {
store.put(encodedKey, encodedValue, stateName)
isFirst = false
} else {
store.merge(encodedKey, encodedValue, stateName)
}
}
val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey)
}

/** Append an entry to the list. */
override def appendValue(newState: S): Unit = {
StateStoreErrors.requireNonNullStateValue(newState, stateName)
store.merge(stateTypesEncoder.encodeGroupingKey(),
stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey)
}

/** Append an entire list to the existing value. */
override def appendList(newState: Array[S]): Unit = {
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
store.merge(encodedKey, encodedValue, stateName)
}
val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey)
}

/** Remove this state. */
override def clear(): Unit = {
store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
}

private def validateNewState(newState: Array[S]): Unit = {
StateStoreErrors.requireNonNullStateValue(newState, stateName)
StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)

newState.foreach { v =>
StateStoreErrors.requireNonNullStateValue(v, stateName)
}
}

/**
* Clears the user state associated with this grouping key
* if it has expired. This function is called by Spark to perform
* cleanup at the end of transformWithState processing.
*
* Spark uses a secondary index to determine if the user state for
* this grouping key has expired. However, its possible that the user
* has updated the TTL and secondary index is out of date. Implementations
* must validate that the user State has actually expired before cleanup based
* on their own State data.
*
* @param groupingKey grouping key for which cleanup should be performed.
*/
override def clearIfExpired(groupingKey: Array[Byte]): Unit = {
logError(s"### clearIfExpired")
val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
// We clear the list, and use the iterator to put back all of the non-expired values
store.remove(encodedGroupingKey, stateName)
var isFirst = true
unsafeRowValuesIterator.foreach { encodedValue =>
if (!isExpired(encodedValue)) {
if (isFirst) {
store.put(encodedGroupingKey, encodedValue, stateName)
isFirst = false
} else {
store.merge(encodedGroupingKey, encodedValue, stateName)
}
}
}
}

private def isExpired(valueRow: UnsafeRow): Boolean = {
logError(s"### isExpired, ${batchTimestampMs}")
val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs))
}
/*
* Internal methods to probe state for testing. The below methods exist for unit tests
* to read the state ttl values, and ensure that values are persisted correctly in
* the underlying state store.
*/

/**
* Retrieves the value from State even if its expired. This method is used
* in tests to read the state store value, and ensure if its cleaned up at the
* end of the micro-batch.
*/
private[sql] def getWithoutEnforcingTTL(): Iterator[Option[S]] = {
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
new Iterator[Option[S]] {
override def hasNext: Boolean = {
unsafeRowValuesIterator.hasNext
}

override def next(): Option[S] = {
val valueUnsafeRow = unsafeRowValuesIterator.next()
if (valueUnsafeRow != null) {
val resState = stateTypesEncoder.decodeValue(valueUnsafeRow)
Some(resState)
} else {
None
}
}
}
}

/**
* Read the ttl value associated with the grouping key.
*/
private[sql] def getTTLValues(): Iterator[Option[Long]] = {
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
new Iterator[Option[Long]] {
override def hasNext: Boolean = {
unsafeRowValuesIterator.hasNext
}

override def next(): Option[Long] = {
val valueUnsafeRow = unsafeRowValuesIterator.next()
stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow)
}
}
}
/**
* Get all ttl values stored in ttl state for current implicit
* grouping key.
*/
private[sql] def getValuesInTTLState(): Iterator[Long] = {
val ttlIterator = ttlIndexIterator()
val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey()
var nextValue: Option[Long] = None

new Iterator[Long] {
override def hasNext: Boolean = {
while (nextValue.isEmpty && ttlIterator.hasNext) {
val nextTtlValue = ttlIterator.next()
val groupingKey = nextTtlValue.groupingKey
if (groupingKey sameElements implicitGroupingKey) {
nextValue = Some(nextTtlValue.expirationMs)
}
}
nextValue.isDefined
}

override def next(): Long = {
val result = nextValue.get
nextValue = None
result
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,35 @@ class StatefulProcessorHandleImpl(
throw StateStoreErrors.ttlMustBePositive("update", stateName)
}
}

/**
* Function to create new or return existing list state variable of given type
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
* from the state store. Any state update resets the ttl to current processing time plus
* ttlDuration.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
*
* @param stateName - name of the state variable
* @param valEncoder - SQL encoder for state variable
* @param ttlConfig - the ttl configuration (time to live duration etc.)
* @tparam T - type of state variable
* @return - instance of ValueState of type T that can be used to store state persistently
*/
override def getListState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T] = {

verifyStateVarOperations("get_list_state")
validateTTLConfig(ttlConfig, stateName)

assert(batchTimestampMs.isDefined)
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
ttlStates.add(listStateWithTTL)

listStateWithTTL
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming

import java.time.Duration

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore}
Expand Down Expand Up @@ -81,7 +82,7 @@ abstract class SingleKeyTTLStateImpl(
stateName: String,
store: StateStore,
ttlExpirationMs: Long)
extends TTLState {
extends TTLState with Logging {

import org.apache.spark.sql.execution.streaming.StateTTLSchema._

Expand All @@ -98,6 +99,7 @@ abstract class SingleKeyTTLStateImpl(
def upsertTTLForStateKey(
expirationMs: Long,
groupingKey: Array[Byte]): Unit = {
logError(s"### upsertTTLForStateKey: expirationMs=$expirationMs")
val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey))
store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName)
}
Expand Down Expand Up @@ -138,7 +140,7 @@ abstract class SingleKeyTTLStateImpl(
/**
* Helper methods for user State TTL.
*/
object StateTTL {
object StateTTL extends Logging {
def calculateExpirationTimeForDuration(
ttlDuration: Duration,
batchTtlExpirationMs: Long): Long = {
Expand All @@ -148,6 +150,8 @@ object StateTTL {
def isExpired(
expirationMs: Long,
batchTtlExpirationMs: Long): Boolean = {
logError(s"### isExpired: expirationMs=$expirationMs," +
s" batchTtlExpirationMs=$batchTtlExpirationMs")
batchTtlExpirationMs >= expirationMs
}
}
Loading