-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation #21733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4252f41
941b88d
abec57f
977428c
63dfb5d
e844636
26701a3
60c231e
b4a3807
e0ee04a
8629f59
65801a6
19888ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
* TODO list * add docs
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,9 +89,13 @@ package object state { | |
| sealed trait StreamingAggregationStateManager extends Serializable { | ||
| def getKey(row: InternalRow): UnsafeRow | ||
| def getStateValueSchema: StructType | ||
| def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow | ||
| def get(store: StateStore, key: UnsafeRow): UnsafeRow | ||
| def put(store: StateStore, row: UnsafeRow): Unit | ||
| def commit(store: StateStore): Long | ||
| def remove(store: StateStore, key: UnsafeRow): Unit | ||
| def iterator(store: StateStore): Iterator[UnsafeRowPair] | ||
| def keys(store: StateStore): Iterator[UnsafeRow] | ||
| def values(store: StateStore): Iterator[UnsafeRow] | ||
| } | ||
|
|
||
| object StreamingAggregationStateManager extends Logging { | ||
|
|
@@ -118,6 +122,15 @@ package object state { | |
| GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) | ||
|
|
||
| def getKey(row: InternalRow): UnsafeRow = keyProjector(row) | ||
|
|
||
| override def commit(store: StateStore): Long = store.commit() | ||
|
||
|
|
||
| override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) | ||
|
|
||
| override def keys(store: StateStore): Iterator[UnsafeRow] = { | ||
| // discard and don't convert values to avoid computation | ||
| store.getRange(None, None).map(_.key) | ||
| } | ||
| } | ||
|
|
||
| class StreamingAggregationStateManagerImplV1( | ||
|
|
@@ -127,17 +140,21 @@ package object state { | |
|
|
||
| override def getStateValueSchema: StructType = inputRowAttributes.toStructType | ||
|
|
||
| override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { | ||
| rowPair.value | ||
| } | ||
|
|
||
| override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { | ||
| store.get(key) | ||
| } | ||
|
|
||
| override def put(store: StateStore, row: UnsafeRow): Unit = { | ||
| store.put(getKey(row), row) | ||
| } | ||
|
|
||
| override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { | ||
| store.iterator() | ||
| } | ||
|
|
||
| override def values(store: StateStore): Iterator[UnsafeRow] = { | ||
| store.iterator().map(_.value) | ||
| } | ||
| } | ||
|
|
||
| class StreamingAggregationStateManagerImplV2( | ||
|
|
@@ -161,15 +178,6 @@ package object state { | |
|
|
||
| override def getStateValueSchema: StructType = valueExpressions.toStructType | ||
|
|
||
| override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { | ||
| val joinedRow = joiner.join(rowPair.key, rowPair.value) | ||
| if (needToProjectToRestoreValue) { | ||
| restoreValueProjector(joinedRow) | ||
| } else { | ||
| joinedRow | ||
| } | ||
| } | ||
|
|
||
| override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { | ||
| val savedState = store.get(key) | ||
| if (savedState == null) { | ||
|
|
@@ -189,6 +197,23 @@ package object state { | |
| val value = valueProjector(row) | ||
| store.put(key, value) | ||
| } | ||
|
|
||
| override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { | ||
| store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair))) | ||
| } | ||
|
|
||
| override def values(store: StateStore): Iterator[UnsafeRow] = { | ||
| store.iterator().map(rowPair => restoreOriginRow(rowPair)) | ||
| } | ||
|
|
||
| private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { | ||
|
||
| val joinedRow = joiner.join(rowPair.key, rowPair.value) | ||
| if (needToProjectToRestoreValue) { | ||
| restoreValueProjector(joinedRow) | ||
| } else { | ||
| joinedRow | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -165,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode { | |
| } | ||
| } | ||
| } | ||
|
|
||
| protected def removeKeysOlderThanWatermark(storeManager: StreamingAggregationStateManager, | ||
|
||
| store: StateStore) | ||
| : Unit = { | ||
| if (watermarkPredicateForKeys.nonEmpty) { | ||
| storeManager.keys(store).foreach { keyRow => | ||
| if (watermarkPredicateForKeys.get.eval(keyRow)) { | ||
| store.remove(keyRow) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| object WatermarkSupport { | ||
|
|
@@ -293,12 +305,12 @@ case class StateStoreSaveExec( | |
| } | ||
| allRemovalsTimeMs += 0 | ||
| commitTimeMs += timeTakenMs { | ||
| store.commit() | ||
| stateManager.commit(store) | ||
| } | ||
| setStoreMetrics(store) | ||
| store.iterator().map { rowPair => | ||
| stateManager.values(store).map { valueRow => | ||
| numOutputRows += 1 | ||
| stateManager.restoreOriginRow(rowPair) | ||
| valueRow | ||
| } | ||
|
|
||
| // Update and output only rows being evicted from the StateStore | ||
|
|
@@ -314,16 +326,16 @@ case class StateStoreSaveExec( | |
| } | ||
|
|
||
| val removalStartTimeNs = System.nanoTime | ||
| val rangeIter = store.getRange(None, None) | ||
| val rangeIter = stateManager.iterator(store) | ||
|
|
||
| new NextIterator[InternalRow] { | ||
| override protected def getNext(): InternalRow = { | ||
| var removedValueRow: InternalRow = null | ||
| while(rangeIter.hasNext && removedValueRow == null) { | ||
| val rowPair = rangeIter.next() | ||
| if (watermarkPredicateForKeys.get.eval(rowPair.key)) { | ||
| store.remove(rowPair.key) | ||
| removedValueRow = stateManager.restoreOriginRow(rowPair) | ||
| stateManager.remove(store, rowPair.key) | ||
| removedValueRow = rowPair.value | ||
| } | ||
| } | ||
| if (removedValueRow == null) { | ||
|
|
@@ -336,7 +348,7 @@ case class StateStoreSaveExec( | |
|
|
||
| override protected def close(): Unit = { | ||
| allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) | ||
| commitTimeMs += timeTakenMs { store.commit() } | ||
| commitTimeMs += timeTakenMs { stateManager.commit(store) } | ||
| setStoreMetrics(store) | ||
| } | ||
| } | ||
|
|
@@ -370,7 +382,7 @@ case class StateStoreSaveExec( | |
|
|
||
| // Remove old aggregates if watermark specified | ||
| allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } | ||
| commitTimeMs += timeTakenMs { store.commit() } | ||
| commitTimeMs += timeTakenMs { stateManager.commit(store) } | ||
| setStoreMetrics(store) | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why is the input typed InternalRow where everything else is UnsafeRow? seems inconsistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getKeywas basically UnsafeProjection in statefulOperator so didn't necessarily require UnsafeRow. I just followed the usage to make it less restrict, but we know, in realityrowwill be always UnsafeRow. So OK to fix if it provides consistency.