Changes from all commits
IncrementalExecution.scala

@@ -62,7 +62,7 @@ class IncrementalExecution(
     StreamingDeduplicationStrategy :: Nil
   }
 
-  private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
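Context for the visibility change: numStateStores resolves the shuffle-partition count from the conf snapshot stored in the checkpoint's offset metadata, falling back to the live session conf, and private[sql] exposes it so the new StreamTest check below can assert against it. A minimal, self-contained sketch of the same fallback pattern in plain Scala (all values are hypothetical):

// Sketch of the fallback above: prefer the shuffle-partition count recorded
// in the checkpointed offset metadata; otherwise use the session's current
// setting. Values are hypothetical.
object NumStateStoresSketch extends App {
  val checkpointedConf = Map("spark.sql.shuffle.partitions" -> "10") // recorded at first run
  val sessionShufflePartitions = 200                                 // current session default

  val numStateStores: Int = checkpointedConf
    .get("spark.sql.shuffle.partitions")
    .map(_.toInt)
    .getOrElse(sessionShufflePartitions)

  assert(numStateStores == 10) // the checkpointed value wins over the session conf
}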
StreamingSymmetricHashJoinExec.scala

@@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec(
   val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+      ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil

Member commented: +1

   override def output: Seq[Attribute] = joinType match {
     case _: InnerLike => left.output ++ right.output
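Context for this change: as the diff shows, ClusteredDistribution now carries the partition count from the operator's stateInfo, so the planner must shuffle each child into exactly that many partitions rather than whatever spark.sql.shuffle.partitions currently is. A self-contained sketch, using simplified stand-in types rather than Spark's actual classes, of why pinning the count matters for a stateful join:

// Simplified stand-ins (not Spark's classes): state store files are keyed by
// partition id, so planning the child with a different shuffle partition
// count would route keys to the wrong stores across restarts.
object DistributionSketch {
  final case class Distribution(keys: Seq[String], requiredNumPartitions: Option[Int])

  // The planner honors an exact partition-count requirement when present;
  // otherwise it falls back to the session's shuffle-partition setting.
  def plannedPartitions(d: Distribution, sessionShufflePartitions: Int): Int =
    d.requiredNumPartitions.getOrElse(sessionShufflePartitions)

  def main(args: Array[String]): Unit = {
    // With the fix: the join pins the count recorded in its stateInfo (say 10),
    // even if the session default is 200.
    assert(plannedPartitions(Distribution(Seq("key"), Some(10)), 200) == 10)
    // Without a requirement, the plan would silently follow the session default.
    assert(plannedPartitions(Distribution(Seq("key"), None), 200) == 200)
  }
}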
DeduplicateSuite.scala

@@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.functions._
 
-class DeduplicateSuite extends StateStoreMetricsTest
-  with BeforeAndAfterAll
-  with StatefulOperatorTest {
+class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
 
   import testImplicits._
 
@@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
@@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
FlatMapGroupsWithStateSuite.scala

@@ -42,8 +42,7 @@ case class RunningCount(count: Long)
 case class Result(key: Long, count: Int)
 
 class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
-  with BeforeAndAfterAll
-  with StatefulOperatorTest {
+  with BeforeAndAfterAll {
 
   import testImplicits._
   import GroupStateImpl._
@@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
-        sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
       assertNumStateRows(total = 2, updated = 2),

This file was deleted. (It defined the StatefulOperatorTest trait and its checkChildOutputHashPartitioning helper, which the suites above no longer mix in.)

StreamTest.scala

@@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.AllTuples
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
@@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         }
       }
 
+      val lastExecution = currentStream.lastExecution
+      if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) {
+        // Verify if stateful operators have correct metadata and distribution
+        // This can often catch hard to debug errors when developing stateful operators
+        lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s =>
+          assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores))
+          s.requiredChildDistribution.foreach { d =>
+            withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") {
+              assert(d.requiredNumPartitions.isDefined)
+              assert(d.requiredNumPartitions.get >= 1)
+              if (d != AllTuples) {
+                assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions)

Contributor commented: can you also verify that this is equal to the number of partitions in the metadata?

+              }
+            }
+          }
+        }
+      }

       val (latestBatchData, allData) = sink match {
         case s: MemorySink => (s.latestBatchData, s.allData)
         case s: MemorySinkV2 => (s.latestBatchData, s.allData)
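A possible shape for the extra check the reviewer requests, sketched as a standalone helper. The helper name and its wiring are hypothetical; in the test above it would compare each d.requiredNumPartitions against the count recorded in the checkpoint's offset metadata, i.e. lastExecution.numStateStores (which is why numStateStores was widened to private[sql] earlier in this diff):

// Hypothetical helper sketching the reviewer's request: the pinned partition
// count in each child distribution should also equal the partition count
// recorded in the checkpoint's offset metadata.
object MetadataCheckSketch {
  def verifyAgainstMetadata(
      requiredNumPartitions: Option[Int],
      metadataNumPartitions: Int): Unit = {
    assert(
      requiredNumPartitions.contains(metadataNumPartitions),
      s"distribution requires $requiredNumPartitions partitions but the offset " +
        s"metadata recorded $metadataNumPartitions")
  }

  // Illustrative usage against the names in the hunk above:
  // verifyAgainstMetadata(d.requiredNumPartitions, lastExecution.numStateStores)
}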
StreamingAggregationSuite.scala

@@ -44,7 +44,7 @@ object FailureSingleton {
 }
 
 class StreamingAggregationSuite extends StateStoreMetricsTest
-  with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
+  with BeforeAndAfterAll with Assertions {
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -281,8 +279,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(inputData, 0L, 5L, 5L, 10L),
       AdvanceManualClock(10 * 1000),
       CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
 
       // advance clock to 20 seconds, should retain keys >= 10
       AddData(inputData, 15L, 15L, 20L),