 
 package org.apache.spark.sql.execution.streaming
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -51,30 +53,35 @@ object MemoryStream {
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends Source with Logging {
+    extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
   protected val encoder = encoderFor[A]
-  protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
+  private val attributes = encoder.schema.toAttributes
+  protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
   protected val output = logicalPlan.output
 
   /**
    * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
    * Stored in a ListBuffer to facilitate removing committed batches.
    */
   @GuardedBy("this")
-  protected val batches = new ListBuffer[Dataset[A]]
+  protected val batches = new ListBuffer[Array[UnsafeRow]]
 
   @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
+  @GuardedBy("this")
+  private var startOffset = new LongOffset(-1)
+
+  @GuardedBy("this")
+  private var endOffset = new LongOffset(-1)
+
   /**
    * Last offset that was discarded, or -1 if no commits have occurred. Note that the value
    * -1 is used in calculations below and isn't just an arbitrary constant.
    */
   @GuardedBy("this")
   protected var lastOffsetCommitted: LongOffset = new LongOffset(-1)
 
-  def schema: StructType = encoder.schema
-
   def toDS(): Dataset[A] = {
     Dataset(sqlContext.sparkSession, logicalPlan)
   }
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   }
 
   def addData(data: TraversableOnce[A]): Offset = {
-    val encoded = data.toVector.map(d => encoder.toRow(d).copy())
-    val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
-    val ds = Dataset[A](sqlContext.sparkSession, plan)
-    logDebug(s"Adding ds: $ds")
+    val objects = data.toSeq
+    val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+    logDebug(s"Adding: $objects")
     this.synchronized {
       currentOffset = currentOffset + 1
-      batches += ds
+      batches += rows
       currentOffset
     }
   }
 
   override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
 
-  override def getOffset: Option[Offset] = synchronized {
-    if (currentOffset.offset == -1) {
-      None
-    } else {
-      Some(currentOffset)
+  override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+    synchronized {
+      startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
+      endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
     }
   }
 
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-    // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
-    val startOrdinal =
-      start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-    val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
-    // Internal buffer only holds the batches after lastCommittedOffset.
-    val newBlocks = synchronized {
-      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
-      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
-      assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
-      batches.slice(sliceStart, sliceEnd)
-    }
+  override def readSchema(): StructType = encoder.schema
 
-    if (newBlocks.isEmpty) {
-      return sqlContext.internalCreateDataFrame(
-        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
-    }
+  override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
+
+  override def getStartOffset: OffsetV2 = synchronized {
+    if (startOffset.offset == -1) null else startOffset
+  }
 
-    logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+  override def getEndOffset: OffsetV2 = synchronized {
+    if (endOffset.offset == -1) null else endOffset
+  }
 
-    newBlocks
-      .map(_.toDF())
-      .reduceOption(_ union _)
-      .getOrElse {
-        sys.error("No data selected!")
+  override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+    synchronized {
+      // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
+      val startOrdinal = startOffset.offset.toInt + 1
+      val endOrdinal = endOffset.offset.toInt + 1
+
+      // Internal buffer only holds the batches after lastCommittedOffset.
+      val newBlocks = synchronized {
+        val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+        val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+        assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
+        batches.slice(sliceStart, sliceEnd)
       }
+
+      logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
+
+      newBlocks.map { block =>
+        new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+      }.asJava
+    }
   }
 
   private def generateDebugString(
-      blocks: TraversableOnce[Dataset[A]],
+      rows: Seq[UnsafeRow],
       startOrdinal: Int,
       endOrdinal: Int): String = {
-    val originalUnsupportedCheck =
-      sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
-    try {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
-      s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
-        s"${blocks.flatMap(_.collect()).mkString(", ")}"
-    } finally {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
-    }
+    val fromRow = encoder.resolveAndBind().fromRow _
+    s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
+      s"${rows.map(row => fromRow(row)).mkString(", ")}"
   }
 
-  override def commit(end: Offset): Unit = synchronized {
+  override def commit(end: OffsetV2): Unit = synchronized {
     def check(newOffset: LongOffset): Unit = {
       val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
 
@@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def reset(): Unit = synchronized {
     batches.clear()
+    startOffset = LongOffset(-1)
+    endOffset = LongOffset(-1)
     currentOffset = new LongOffset(-1)
     lastOffsetCommitted = new LongOffset(-1)
   }
 }
 
+
+class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
+  extends DataReaderFactory[UnsafeRow] {
+  override def createDataReader(): DataReader[UnsafeRow] = {
+    new DataReader[UnsafeRow] {
+      private var currentIndex = -1
+
+      override def next(): Boolean = {
+        // Return true as long as the new index is in the array.
+        currentIndex += 1
+        currentIndex < records.length
+      }
+
+      override def get(): UnsafeRow = records(currentIndex)
+
+      override def close(): Unit = {}
+    }
+  }
+}
+
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
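A minimal sketch of the micro-batch read cycle this change enables. It is illustrative only: it assumes a `sqlContext` is in scope (for example from a SparkSession in a test), and it uses only the methods visible in the diff above; the sample values are arbitrary.

  import java.util.Optional
  import scala.collection.JavaConverters._
  import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.catalyst.expressions.UnsafeRow

  // `sqlContext` is assumed to be in scope; its implicits provide Encoder[Int].
  import sqlContext.implicits._

  val stream = MemoryStream[Int](0, sqlContext)
  stream.addData(Seq(1, 2, 3))

  // The engine picks the range to read: an empty start means "from the
  // beginning"; an empty end defaults to everything added so far.
  stream.setOffsetRange(Optional.empty(), Optional.empty())

  // Each factory wraps one appended block of UnsafeRows.
  val rows = ArrayBuffer.empty[UnsafeRow]
  stream.createUnsafeRowReaderFactories().asScala.foreach { factory =>
    val reader = factory.createDataReader()
    while (reader.next()) {
      rows += reader.get()
    }
    reader.close()
  }

  // Once the batch has been processed, committing the end offset lets
  // MemoryStream drop the buffered rows.
  stream.commit(stream.getEndOffset)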