
Commit 30295bf

[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs
## What changes were proposed in this pull request?

This PR migrates MemoryStream to the DataSourceV2 APIs. One additional change is in the keys reported in StreamingQueryProgress.durationMs: "getOffset" and "getBatch" are replaced with "setOffsetRange" and "getEndOffset", as tracking these makes more sense under the new API. Unit tests are updated accordingly.

## How was this patch tested?

Existing unit tests, plus a few updated unit tests.

Author: Tathagata Das <[email protected]>
Author: Burak Yavuz <[email protected]>

Closes #20445 from tdas/SPARK-23092.
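For readers new to the v2 API, here is a minimal sketch (illustrative only, not part of this commit) of the two calls whose timings now surface in durationMs, assuming a `reader: MicroBatchReader` for some source:

  import java.util.Optional
  import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}

  // Hypothetical helper mirroring what MicroBatchExecution does per trigger:
  // these two timed phases replace the old "getOffset" entry.
  def fetchLatestOffset(reader: MicroBatchReader, lastKnown: Option[OffsetV2]): Option[OffsetV2] = {
    // Timed as "setOffsetRange": let the source infer its available end offset.
    reader.setOffsetRange(Optional.ofNullable(lastKnown.orNull), Optional.empty())
    // Timed as "getEndOffset": read back the end offset the source settled on.
    Option(reader.getEndOffset)
  }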
1 parent 9841ae0 commit 30295bf

File tree

9 files changed: +171 −134 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala

Lines changed: 3 additions & 1 deletion
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+
 /**
  * A simple offset for sources that produce a single linear stream of data.
  */
-case class LongOffset(offset: Long) extends Offset {
+case class LongOffset(offset: Long) extends OffsetV2 {
 
   override val json = offset.toString
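Since LongOffset is now a v2 offset, the engine recovers it from its JSON form via the reader's deserializeOffset. A quick sketch of that round trip (illustrative):

  val off = LongOffset(42)
  assert(off.json == "42")                    // json is just the Long rendered as a string
  val restored = LongOffset(off.json.toLong)  // same pattern MemoryStream.deserializeOffset uses below
  assert(restored == off)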

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala

Lines changed: 16 additions & 11 deletions
@@ -270,16 +270,17 @@ class MicroBatchExecution(
         }
       case s: MicroBatchReader =>
         updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("getOffset") {
-          // Once v1 streaming source execution is gone, we can refactor this away.
-          // For now, we set the range here to get the source to infer the available end offset,
-          // get that offset, and then set the range again when we later execute.
-          s.setOffsetRange(
-            toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
-            Optional.empty())
-
-          (s, Some(s.getEndOffset))
+        reportTimeTaken("setOffsetRange") {
+          // Once v1 streaming source execution is gone, we can refactor this away.
+          // For now, we set the range here to get the source to infer the available end offset,
+          // get that offset, and then set the range again when we later execute.
+          s.setOffsetRange(
+            toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
+            Optional.empty())
         }
+
+        val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
+        (s, Option(currentOffset))
     }.toMap
     availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
 
@@ -401,10 +402,14 @@ class MicroBatchExecution(
       case (reader: MicroBatchReader, available)
         if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
         val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+        val availableV2: OffsetV2 = available match {
+          case v1: SerializedOffset => reader.deserializeOffset(v1.json)
+          case v2: OffsetV2 => v2
+        }
         reader.setOffsetRange(
           toJava(current),
-          Optional.of(available.asInstanceOf[OffsetV2]))
-        logDebug(s"Retrieving data from $reader: $current -> $available")
+          Optional.of(availableV2))
+        logDebug(s"Retrieving data from $reader: $current -> $availableV2")
         Some(reader ->
           new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
       case _ => None
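The availableV2 pattern above exists because offsets restored from the offset log come back as v1 SerializedOffset wrappers rather than v2 offsets. The same conversion as a standalone sketch (helper name is illustrative, not from this commit):

  // Re-hydrate a possibly log-recovered offset into the v2 type the reader expects.
  def toV2Offset(available: Offset, reader: MicroBatchReader): OffsetV2 = available match {
    case serialized: SerializedOffset => reader.deserializeOffset(serialized.json)
    case v2: OffsetV2 => v2
  }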

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 79 additions & 53 deletions
@@ -17,21 +17,23 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -51,30 +53,35 @@ object MemoryStream {
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-  extends Source with Logging {
+  extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
   protected val encoder = encoderFor[A]
-  protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
+  private val attributes = encoder.schema.toAttributes
+  protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
   protected val output = logicalPlan.output
 
   /**
    * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
    * Stored in a ListBuffer to facilitate removing committed batches.
    */
   @GuardedBy("this")
-  protected val batches = new ListBuffer[Dataset[A]]
+  protected val batches = new ListBuffer[Array[UnsafeRow]]
 
   @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
+  @GuardedBy("this")
+  private var startOffset = new LongOffset(-1)
+
+  @GuardedBy("this")
+  private var endOffset = new LongOffset(-1)
+
   /**
    * Last offset that was discarded, or -1 if no commits have occurred. Note that the value
    * -1 is used in calculations below and isn't just an arbitrary constant.
    */
   @GuardedBy("this")
   protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
 
-  def schema: StructType = encoder.schema
-
   def toDS(): Dataset[A] = {
     Dataset(sqlContext.sparkSession, logicalPlan)
   }
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   }
 
   def addData(data: TraversableOnce[A]): Offset = {
-    val encoded = data.toVector.map(d => encoder.toRow(d).copy())
-    val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
-    val ds = Dataset[A](sqlContext.sparkSession, plan)
-    logDebug(s"Adding ds: $ds")
+    val objects = data.toSeq
+    val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+    logDebug(s"Adding: $objects")
     this.synchronized {
       currentOffset = currentOffset + 1
-      batches += ds
+      batches += rows
       currentOffset
     }
   }
 
   override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
 
-  override def getOffset: Option[Offset] = synchronized {
-    if (currentOffset.offset == -1) {
-      None
-    } else {
-      Some(currentOffset)
+  override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+    synchronized {
+      startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
+      endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
     }
   }
 
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-    // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
-    val startOrdinal =
-      start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-    val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
-    // Internal buffer only holds the batches after lastCommittedOffset.
-    val newBlocks = synchronized {
-      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
-      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
-      assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
-      batches.slice(sliceStart, sliceEnd)
-    }
+  override def readSchema(): StructType = encoder.schema
 
-    if (newBlocks.isEmpty) {
-      return sqlContext.internalCreateDataFrame(
-        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
-    }
+  override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
+
+  override def getStartOffset: OffsetV2 = synchronized {
+    if (startOffset.offset == -1) null else startOffset
+  }
 
-    logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+  override def getEndOffset: OffsetV2 = synchronized {
+    if (endOffset.offset == -1) null else endOffset
+  }
 
-    newBlocks
-      .map(_.toDF())
-      .reduceOption(_ union _)
-      .getOrElse {
-        sys.error("No data selected!")
+  override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+    synchronized {
+      // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
+      val startOrdinal = startOffset.offset.toInt + 1
+      val endOrdinal = endOffset.offset.toInt + 1
+
+      // Internal buffer only holds the batches after lastCommittedOffset.
+      val newBlocks = synchronized {
+        val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+        val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+        assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
+        batches.slice(sliceStart, sliceEnd)
       }
+
+      logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
+
+      newBlocks.map { block =>
+        new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+      }.asJava
+    }
   }
 
   private def generateDebugString(
-      blocks: TraversableOnce[Dataset[A]],
+      rows: Seq[UnsafeRow],
       startOrdinal: Int,
       endOrdinal: Int): String = {
-    val originalUnsupportedCheck =
-      sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
-    try {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
-      s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
-        s"${blocks.flatMap(_.collect()).mkString(", ")}"
-    } finally {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
-    }
+    val fromRow = encoder.resolveAndBind().fromRow _
+    s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
+      s"${rows.map(row => fromRow(row)).mkString(", ")}"
   }
 
-  override def commit(end: Offset): Unit = synchronized {
+  override def commit(end: OffsetV2): Unit = synchronized {
     def check(newOffset: LongOffset): Unit = {
       val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

@@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def reset(): Unit = synchronized {
     batches.clear()
+    startOffset = LongOffset(-1)
+    endOffset = LongOffset(-1)
     currentOffset = new LongOffset(-1)
     lastOffsetCommitted = new LongOffset(-1)
   }
 }
 
+
+class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
+  extends DataReaderFactory[UnsafeRow] {
+  override def createDataReader(): DataReader[UnsafeRow] = {
+    new DataReader[UnsafeRow] {
+      private var currentIndex = -1
+
+      override def next(): Boolean = {
+        // Return true as long as the new index is in the array.
+        currentIndex += 1
+        currentIndex < records.length
+      }
+
+      override def get(): UnsafeRow = records(currentIndex)
+
+      override def close(): Unit = {}
+    }
+  }
+}
+
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
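A sketch of driving the migrated MemoryStream directly through its MicroBatchReader surface, roughly the sequence of calls MicroBatchExecution now makes (assumes an active SparkSession named `spark`; illustrative, not from this commit):

  import java.util.Optional
  import scala.collection.JavaConverters._
  import org.apache.spark.sql.execution.streaming.MemoryStream

  import spark.implicits._
  implicit val sqlCtx = spark.sqlContext

  val stream = MemoryStream[Int]
  stream.addData(1 to 5)

  // 1. Let the source infer its end offset (timed as "setOffsetRange" / "getEndOffset").
  stream.setOffsetRange(Optional.empty(), Optional.empty())
  val end = stream.getEndOffset                 // LongOffset(0) after one addData call

  // 2. Fix the range for this batch and plan one reader factory per buffered block.
  stream.setOffsetRange(Optional.empty(), Optional.of(end))
  val factories = stream.createUnsafeRowReaderFactories().asScala

  // 3. Read the UnsafeRows back out, then commit so the buffer can be trimmed.
  for (factory <- factories) {
    val reader = factory.createDataReader()
    while (reader.next()) {
      println(reader.get())                     // each UnsafeRow holds one encoded Int
    }
    reader.close()
  }
  stream.commit(end)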

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor
 }
 
 class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
-  var currentIndex = -1
+  private var currentIndex = -1
 
   override def next(): Boolean = {
     // Return true as long as the new index is in the seq.

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala

Lines changed: 20 additions & 35 deletions
@@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
       .foreach(new TestForeachWriter())
       .start()
 
-    // -- batch 0 ---------------------------------------
-    input.addData(1, 2, 3, 4)
-    query.processAllAvailable()
+    def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
+      import ForeachSinkSuite._
 
-    var expectedEventsForPartition0 = Seq(
-      ForeachSinkSuite.Open(partition = 0, version = 0),
-      ForeachSinkSuite.Process(value = 2),
-      ForeachSinkSuite.Process(value = 3),
-      ForeachSinkSuite.Close(None)
-    )
-    var expectedEventsForPartition1 = Seq(
-      ForeachSinkSuite.Open(partition = 1, version = 0),
-      ForeachSinkSuite.Process(value = 1),
-      ForeachSinkSuite.Process(value = 4),
-      ForeachSinkSuite.Close(None)
-    )
+      val events = ForeachSinkSuite.allEvents()
+      assert(events.size === 2) // one seq of events for each of the 2 partitions
 
-    var allEvents = ForeachSinkSuite.allEvents()
-    assert(allEvents.size === 2)
-    assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+      // Verify both seq of events have an Open event as the first event
+      assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))
+
+      // Verify all the Process event correspond to the expected data
+      val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
+      assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)
+
+      // Verify both seq of events have a Close event as the last event
+      assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
+    }
 
+    // -- batch 0 ---------------------------------------
     ForeachSinkSuite.clear()
+    input.addData(1, 2, 3, 4)
+    query.processAllAvailable()
+    verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
 
     // -- batch 1 ---------------------------------------
+    ForeachSinkSuite.clear()
     input.addData(5, 6, 7, 8)
     query.processAllAvailable()
-
-    expectedEventsForPartition0 = Seq(
-      ForeachSinkSuite.Open(partition = 0, version = 1),
-      ForeachSinkSuite.Process(value = 5),
-      ForeachSinkSuite.Process(value = 7),
-      ForeachSinkSuite.Close(None)
-    )
-    expectedEventsForPartition1 = Seq(
-      ForeachSinkSuite.Open(partition = 1, version = 1),
-      ForeachSinkSuite.Process(value = 6),
-      ForeachSinkSuite.Process(value = 8),
-      ForeachSinkSuite.Close(None)
-    )
-
-    allEvents = ForeachSinkSuite.allEvents()
-    assert(allEvents.size === 2)
-    assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+    verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
 
     query.stop()
   }
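For context on the events being asserted: a ForeachWriter sees one open, one process per row, and one close, per partition per batch version. A minimal writer in the spirit of TestForeachWriter (illustrative; the suite's real writer records events into ForeachSinkSuite instead of printing):

  import org.apache.spark.sql.ForeachWriter

  class RecordingWriter extends ForeachWriter[Int] {
    override def open(partitionId: Long, version: Long): Boolean = {
      println(s"Open(partition=$partitionId, version=$version)")
      true  // true means "do process this partition for this version"
    }
    override def process(value: Int): Unit = println(s"Process($value)")
    override def close(errorOrNull: Throwable): Unit = println(s"Close(${Option(errorOrNull)})")
  }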

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -492,16 +492,16 @@ class StreamSuite extends StreamTest {
 
       val explainWithoutExtended = q.explainInternal(false)
       // `extended = false` only displays the physical plan.
-      assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
-      assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
+      assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
+      assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
       // Use "StateStoreRestore" to verify that it does output a streaming physical plan
       assert(explainWithoutExtended.contains("StateStoreRestore"))
 
       val explainWithExtended = q.explainInternal(true)
       // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical
       // plan.
-      assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
-      assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
+      assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
+      assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
       // Use "StateStoreRestore" to verify that it does output a streaming physical plan
       assert(explainWithExtended.contains("StateStoreRestore"))
     } finally {
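The updated assertions reflect that a MemoryStream-backed query now appears in the plan as a DataSourceV2 scan rather than a LocalTableScan. A quick, hedged way to eyeball this interactively (assuming a started StreamingQuery named `query` that has processed at least one batch):

  // Prints the physical plan of the most recent micro-batch; with the v2
  // MemoryStream it should mention "DataSourceV2Scan" instead of "LocalTableScan".
  query.explain()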

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
     override def toString: String = s"AddData to $source: ${data.mkString(",")}"
 
-    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
       (source, source.addData(data))
     }
   }
