Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP Address a part of review comments from @tdas
* TODO list
  * add docs
  • Loading branch information
HeartSaVioR committed Aug 1, 2018
commit b4a3807631cc8e12df367eeca554749fdd81a5ef
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ package object state {
sealed trait StreamingAggregationStateManager extends Serializable {
def getKey(row: InternalRow): UnsafeRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why is the input typed InternalRow where everything else is UnsafeRow? seems inconsistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getKey was basically UnsafeProjection in statefulOperator so didn't necessarily require UnsafeRow. I just followed the usage to make it less restrict, but we know, in reality row will be always UnsafeRow. So OK to fix if it provides consistency.

def getStateValueSchema: StructType
def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow
def get(store: StateStore, key: UnsafeRow): UnsafeRow
def put(store: StateStore, row: UnsafeRow): Unit
def commit(store: StateStore): Long
def remove(store: StateStore, key: UnsafeRow): Unit
def iterator(store: StateStore): Iterator[UnsafeRowPair]
def keys(store: StateStore): Iterator[UnsafeRow]
def values(store: StateStore): Iterator[UnsafeRow]
}

object StreamingAggregationStateManager extends Logging {
Expand All @@ -118,6 +122,15 @@ package object state {
GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

def getKey(row: InternalRow): UnsafeRow = keyProjector(row)

override def commit(store: StateStore): Long = store.commit()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really does not need to be in this interface as this is not customized and is unlikely to be ever customized across implementations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is actually based on your review comment: always use state manager and don't directly access state store whenever possible. If your suggestion only applies to operations I can remove commit() from this interface.


override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key)

override def keys(store: StateStore): Iterator[UnsafeRow] = {
// discard and don't convert values to avoid computation
store.getRange(None, None).map(_.key)
}
}

class StreamingAggregationStateManagerImplV1(
Expand All @@ -127,17 +140,21 @@ package object state {

override def getStateValueSchema: StructType = inputRowAttributes.toStructType

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
rowPair.value
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
store.get(key)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
store.put(getKey(row), row)
}

override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
store.iterator()
}

override def values(store: StateStore): Iterator[UnsafeRow] = {
store.iterator().map(_.value)
}
}

class StreamingAggregationStateManagerImplV2(
Expand All @@ -161,15 +178,6 @@ package object state {

override def getStateValueSchema: StructType = valueExpressions.toStructType

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
val savedState = store.get(key)
if (savedState == null) {
Expand All @@ -189,6 +197,23 @@ package object state {
val value = valueProjector(row)
store.put(key, value)
}

override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair)))
}

override def values(store: StateStore): Iterator[UnsafeRow] = {
store.iterator().map(rowPair => restoreOriginRow(rowPair))
}

private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to restoreOriginalRow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will rename.

val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode {
}
}
}

protected def removeKeysOlderThanWatermark(storeManager: StreamingAggregationStateManager,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

incorrect indent of parameters

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually... where is this used? This does not seem to be used anywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed spot. It will be called from Update mode of StateStoreSaveExec. Will address.

store: StateStore)
: Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
storeManager.keys(store).foreach { keyRow =>
if (watermarkPredicateForKeys.get.eval(keyRow)) {
store.remove(keyRow)
}
}
}
}
}

object WatermarkSupport {
Expand Down Expand Up @@ -293,12 +305,12 @@ case class StateStoreSaveExec(
}
allRemovalsTimeMs += 0
commitTimeMs += timeTakenMs {
store.commit()
stateManager.commit(store)
}
setStoreMetrics(store)
store.iterator().map { rowPair =>
stateManager.values(store).map { valueRow =>
numOutputRows += 1
stateManager.restoreOriginRow(rowPair)
valueRow
}

// Update and output only rows being evicted from the StateStore
Expand All @@ -314,16 +326,16 @@ case class StateStoreSaveExec(
}

val removalStartTimeNs = System.nanoTime
val rangeIter = store.getRange(None, None)
val rangeIter = stateManager.iterator(store)

new NextIterator[InternalRow] {
override protected def getNext(): InternalRow = {
var removedValueRow: InternalRow = null
while(rangeIter.hasNext && removedValueRow == null) {
val rowPair = rangeIter.next()
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
store.remove(rowPair.key)
removedValueRow = stateManager.restoreOriginRow(rowPair)
stateManager.remove(store, rowPair.key)
removedValueRow = rowPair.value
}
}
if (removedValueRow == null) {
Expand All @@ -336,7 +348,7 @@ case class StateStoreSaveExec(

override protected def close(): Unit = {
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
commitTimeMs += timeTakenMs { store.commit() }
commitTimeMs += timeTakenMs { stateManager.commit(store) }
setStoreMetrics(store)
}
}
Expand Down Expand Up @@ -370,7 +382,7 @@ case class StateStoreSaveExec(

// Remove old aggregates if watermark specified
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
commitTimeMs += timeTakenMs { stateManager.commit(store) }
setStoreMetrics(store)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class StreamingAggregationStateManagerSuite extends StreamTest {
stateManager.put(memoryStateStore, inputRow)

assert(memoryStateStore.iterator().size === 1)
assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size)

val keyRow = stateManager.getKey(inputRow)
assert(keyRow === expectedStateKey)
Expand All @@ -111,7 +112,15 @@ class StreamingAggregationStateManagerSuite extends StreamTest {
val pair = memoryStateStore.iterator().next()
assert(pair.key === keyRow)
assert(pair.value === expectedStateValue)
assert(stateManager.restoreOriginRow(pair) === inputRow)

// iterate with state manager and see whether original rows are returned as values
val pairFromStateManager = stateManager.iterator(memoryStateStore).next()
assert(pairFromStateManager.key === keyRow)
assert(pairFromStateManager.value === inputRow)

// following as keys and values
assert(stateManager.keys(memoryStateStore).next() === keyRow)
assert(stateManager.values(memoryStateStore).next() === inputRow)

// verify the stored value once again via get
assert(memoryStateStore.get(keyRow) === expectedStateValue)
Expand Down