[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs #20445
Changes from 19 commits
|
|
@@ -17,21 +17,23 @@ | |
|
|
||
| package org.apache.spark.sql.execution.streaming | ||
|
|
||
| import java.{util => ju} | ||
| import java.util.Optional | ||
| import java.util.concurrent.atomic.AtomicInteger | ||
| import javax.annotation.concurrent.GuardedBy | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.collection.mutable | ||
| import scala.collection.mutable.{ArrayBuffer, ListBuffer} | ||
| import scala.util.control.NonFatal | ||
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.catalyst.encoders.encoderFor | ||
| import org.apache.spark.sql.catalyst.expressions.Attribute | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} | ||
| import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} | ||
| import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ | ||
| import org.apache.spark.sql.execution.SQLExecution | ||
| import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} | ||
| import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} | ||
| import org.apache.spark.sql.streaming.OutputMode | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.Utils | ||
|
|
@@ -51,30 +53,35 @@ object MemoryStream { | |
| * available. | ||
| */ | ||
| case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | ||
| extends Source with Logging { | ||
| extends MicroBatchReader with SupportsScanUnsafeRow with Logging { | ||
| protected val encoder = encoderFor[A] | ||
| protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) | ||
| private val attributes = encoder.schema.toAttributes | ||
| protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) | ||
| protected val output = logicalPlan.output | ||
|
|
||
| /** | ||
| * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. | ||
| * Stored in a ListBuffer to facilitate removing committed batches. | ||
| */ | ||
| @GuardedBy("this") | ||
| protected val batches = new ListBuffer[Dataset[A]] | ||
| protected val batches = new ListBuffer[Array[UnsafeRow]] | ||
|
|
||
| @GuardedBy("this") | ||
| protected var currentOffset: LongOffset = new LongOffset(-1) | ||
|
|
||
| @GuardedBy("this") | ||
| private var startOffset = new LongOffset(-1) | ||
|
|
||
| @GuardedBy("this") | ||
| private var endOffset = new LongOffset(-1) | ||
|
|
||
| /** | ||
| * Last offset that was discarded, or -1 if no commits have occurred. Note that the value | ||
| * -1 is used in calculations below and isn't just an arbitrary constant. | ||
| */ | ||
| @GuardedBy("this") | ||
| protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) | ||
|
|
||
| def schema: StructType = encoder.schema | ||
|
|
||
| def toDS(): Dataset[A] = { | ||
| Dataset(sqlContext.sparkSession, logicalPlan) | ||
| } | ||
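For orientation, here is a minimal sketch of how a `MemoryStream` is typically driven end to end; the SparkSession setup, sink choice, and query name below are illustrative and not part of this PR (test suites get their session from `SharedSQLContext` instead):

```scala
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

// Illustrative local setup only.
val spark = SparkSession.builder().master("local[2]").appName("memory-stream-sketch").getOrCreate()
import spark.implicits._                                   // supplies Encoder[Int]
implicit val sqlContext: SQLContext = spark.sqlContext     // picked up by MemoryStream.apply

val input = MemoryStream[Int]
val query = input.toDS()
  .writeStream
  .format("memory")          // illustrative sink; the suite below uses foreach instead
  .queryName("sketch")
  .start()

input.addData(1, 2, 3)       // buffers one batch and advances currentOffset
query.processAllAvailable()  // drains it through the micro-batch read path
query.stop()
```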
|
|
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
| } | ||
|
|
||
| def addData(data: TraversableOnce[A]): Offset = { | ||
| val encoded = data.toVector.map(d => encoder.toRow(d).copy()) | ||
| val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) | ||
| val ds = Dataset[A](sqlContext.sparkSession, plan) | ||
| logDebug(s"Adding ds: $ds") | ||
| val objects = data.toSeq | ||
| val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray | ||
| logDebug(s"Adding: $objects") | ||
| this.synchronized { | ||
| currentOffset = currentOffset + 1 | ||
| batches += ds | ||
| batches += rows | ||
| currentOffset | ||
| } | ||
| } | ||
|
|
||
| override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" | ||
|
|
||
| override def getOffset: Option[Offset] = synchronized { | ||
| if (currentOffset.offset == -1) { | ||
| None | ||
| } else { | ||
| Some(currentOffset) | ||
| override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { | ||
| synchronized { | ||
| startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] | ||
| endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] | ||
| } | ||
| } | ||
|
|
||
| override def getBatch(start: Option[Offset], end: Offset): DataFrame = { | ||
| // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) | ||
| val startOrdinal = | ||
| start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 | ||
| val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 | ||
|
|
||
| // Internal buffer only holds the batches after lastCommittedOffset. | ||
| val newBlocks = synchronized { | ||
| val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 | ||
| val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 | ||
| assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") | ||
| batches.slice(sliceStart, sliceEnd) | ||
| } | ||
| override def readSchema(): StructType = encoder.schema | ||
|
|
||
| if (newBlocks.isEmpty) { | ||
| return sqlContext.internalCreateDataFrame( | ||
| sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) | ||
| } | ||
| override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) | ||
|
|
||
| override def getStartOffset: OffsetV2 = synchronized { | ||
| if (startOffset.offset == -1) null else startOffset | ||
| } | ||
|
|
||
| logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) | ||
| override def getEndOffset: OffsetV2 = synchronized { | ||
| if (endOffset.offset == -1) null else endOffset | ||
| } | ||
|
|
||
| newBlocks | ||
| .map(_.toDF()) | ||
| .reduceOption(_ union _) | ||
| .getOrElse { | ||
| sys.error("No data selected!") | ||
| override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { | ||
| synchronized { | ||
| // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) | ||
| val startOrdinal = startOffset.offset.toInt + 1 | ||
| val endOrdinal = endOffset.offset.toInt + 1 | ||
|
|
||
| // Internal buffer only holds the batches after lastCommittedOffset. | ||
| val newBlocks = synchronized { | ||
| val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 | ||
| val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 | ||
| assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") | ||
| batches.slice(sliceStart, sliceEnd) | ||
| } | ||
|
|
||
| logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) | ||
|
|
||
| newBlocks.map { block => | ||
| new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] | ||
| }.asJava | ||
| } | ||
| } | ||
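A worked example of the ordinal/slice arithmetic above, using illustrative values that are not taken from the diff:

```scala
// Suppose lastOffsetCommitted = 1, so `batches` currently holds the row blocks for
// offsets 2, 3 and 4 at indices 0, 1 and 2.
// For a range of startOffset = 2 (exclusive) to endOffset = 4 (inclusive):
val startOrdinal = 2 + 1                 // 3
val endOrdinal   = 4 + 1                 // 5
val sliceStart   = startOrdinal - 1 - 1  // 1
val sliceEnd     = endOrdinal - 1 - 1    // 3
// batches.slice(1, 3) then yields the blocks for offsets 3 and 4,
// i.e. everything after the start offset up to and including the end offset.
```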
|
|
||
| private def generateDebugString( | ||
| blocks: TraversableOnce[Dataset[A]], | ||
| blocks: Seq[UnsafeRow], | ||
|
||
| startOrdinal: Int, | ||
| endOrdinal: Int): String = { | ||
| val originalUnsupportedCheck = | ||
| sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") | ||
| try { | ||
| sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") | ||
| s"MemoryBatch [$startOrdinal, $endOrdinal]: " + | ||
| s"${blocks.flatMap(_.collect()).mkString(", ")}" | ||
| } finally { | ||
| sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) | ||
| } | ||
| val fromRow = encoder.resolveAndBind().fromRow _ | ||
| s"MemoryBatch [$startOrdinal, $endOrdinal]: " + | ||
| s"${blocks.map(row => fromRow(row)).mkString(", ")}" | ||
| } | ||
|
|
||
| override def commit(end: Offset): Unit = synchronized { | ||
| override def commit(end: OffsetV2): Unit = synchronized { | ||
| def check(newOffset: LongOffset): Unit = { | ||
| val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt | ||
|
|
||
|
|
@@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
|
|
||
| def reset(): Unit = synchronized { | ||
| batches.clear() | ||
| startOffset = LongOffset(-1) | ||
| endOffset = LongOffset(-1) | ||
| currentOffset = new LongOffset(-1) | ||
| lastOffsetCommitted = new LongOffset(-1) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) | ||
| extends DataReaderFactory[UnsafeRow] { | ||
| override def createDataReader(): DataReader[UnsafeRow] = { | ||
| new DataReader[UnsafeRow] { | ||
| private var currentIndex = -1 | ||
|
|
||
| override def next(): Boolean = { | ||
| // Return true as long as the new index is in the array. | ||
| currentIndex += 1 | ||
| currentIndex < records.length | ||
| } | ||
|
|
||
| override def get(): UnsafeRow = records(currentIndex) | ||
|
|
||
| override def close(): Unit = {} | ||
| } | ||
| } | ||
| } | ||
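To see how the migrated pieces fit together, here is a rough sketch of the calls an engine makes against this source. It is a simplification of the micro-batch execution path (in the real engine the readers run on executors), and the variable names are illustrative:

```scala
import java.util.Optional
import scala.collection.JavaConverters._

// `stream` is a MemoryStream[Int] created as in the earlier sketch, with data buffered.
stream.addData(1, 2, 3)

// 1. Fix the offset range for the next micro-batch. Per setOffsetRange above, an empty
//    start defaults to LongOffset(-1) and an empty end defaults to the current offset.
stream.setOffsetRange(Optional.empty(), Optional.empty())
val end = stream.getEndOffset

// 2. Plan the batch: one DataReaderFactory per buffered block of UnsafeRows.
val factories = stream.createUnsafeRowReaderFactories().asScala

// 3. Drive each DataReader and count the rows it produces.
var rowCount = 0
factories.foreach { factory =>
  val reader = factory.createDataReader()
  try {
    while (reader.next()) { reader.get(); rowCount += 1 }
  } finally {
    reader.close()
  }
}

// 4. Once the batch is durably handled downstream, let the source discard it.
stream.commit(end)
```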
|
|
||
| /** | ||
| * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit | ||
| * tests and does not provide durability. | ||
|
|
||
|
|
@@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf | |
| .foreach(new TestForeachWriter()) | ||
| .start() | ||
|
|
||
| // -- batch 0 --------------------------------------- | ||
| input.addData(1, 2, 3, 4) | ||
| query.processAllAvailable() | ||
| def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { | ||
| import ForeachSinkSuite._ | ||
|
|
||
| var expectedEventsForPartition0 = Seq( | ||
| ForeachSinkSuite.Open(partition = 0, version = 0), | ||
| ForeachSinkSuite.Process(value = 2), | ||
| ForeachSinkSuite.Process(value = 3), | ||
| ForeachSinkSuite.Close(None) | ||
| ) | ||
| var expectedEventsForPartition1 = Seq( | ||
| ForeachSinkSuite.Open(partition = 1, version = 0), | ||
| ForeachSinkSuite.Process(value = 1), | ||
| ForeachSinkSuite.Process(value = 4), | ||
| ForeachSinkSuite.Close(None) | ||
| ) | ||
| val events = ForeachSinkSuite.allEvents() | ||
|
Contributor (PR author):

This test assumed that the output would arrive in a specific order after repartitioning, which isn't guaranteed, so I rewrote the test to verify the output in an order-independent way. (A self-contained sketch of this pattern appears after the end of this test below.)
||
| assert(events.size === 2) // one seq of events for each of the 2 partitions | ||
|
|
||
| var allEvents = ForeachSinkSuite.allEvents() | ||
| assert(allEvents.size === 2) | ||
| assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) | ||
| // Verify both seqs of events have an Open event as the first event | ||
| assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) | ||
|
|
||
| // Verify all the Process events correspond to the expected data | ||
| val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) | ||
| assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) | ||
|
|
||
| // Verify both seq of events have a Close event as the last event | ||
| assert(events.map(_.last).toSet === Set(Close(None), Close(None))) | ||
| } | ||
|
|
||
| // -- batch 0 --------------------------------------- | ||
| ForeachSinkSuite.clear() | ||
| input.addData(1, 2, 3, 4) | ||
| query.processAllAvailable() | ||
| verifyOutput(expectedVersion = 0, expectedData = 1 to 4) | ||
|
|
||
| // -- batch 1 --------------------------------------- | ||
| ForeachSinkSuite.clear() | ||
| input.addData(5, 6, 7, 8) | ||
| query.processAllAvailable() | ||
|
|
||
| expectedEventsForPartition0 = Seq( | ||
| ForeachSinkSuite.Open(partition = 0, version = 1), | ||
| ForeachSinkSuite.Process(value = 5), | ||
| ForeachSinkSuite.Process(value = 7), | ||
| ForeachSinkSuite.Close(None) | ||
| ) | ||
| expectedEventsForPartition1 = Seq( | ||
| ForeachSinkSuite.Open(partition = 1, version = 1), | ||
| ForeachSinkSuite.Process(value = 6), | ||
| ForeachSinkSuite.Process(value = 8), | ||
| ForeachSinkSuite.Close(None) | ||
| ) | ||
|
|
||
| allEvents = ForeachSinkSuite.allEvents() | ||
| assert(allEvents.size === 2) | ||
| assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) | ||
| verifyOutput(expectedVersion = 1, expectedData = 5 to 8) | ||
|
|
||
| query.stop() | ||
| } | ||
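As referenced in the author's comment above, here is a self-contained sketch of the order-independent assertion style the rewritten test relies on. The event classes are local stand-ins modeled on the suite's Open/Process/Close events, and the sample data is illustrative:

```scala
// Local stand-ins for ForeachSinkSuite's event types.
sealed trait Event
case class Open(partition: Int, version: Long) extends Event
case class Process(value: Int) extends Event
case class Close(error: Option[Throwable]) extends Event

// One Seq of events per partition; after a repartition, neither the order of the
// two Seqs nor the order of the values within them is guaranteed.
val events: Seq[Seq[Event]] = Seq(
  Seq(Open(1, 0), Process(4), Process(1), Close(None)),
  Seq(Open(0, 0), Process(3), Process(2), Close(None))
)

// Compare as sets instead of fixed sequences.
assert(events.map(_.head).toSet == Set(0, 1).map(p => Open(p, version = 0)))
assert(events.flatMap(_.collect { case p: Process => p }).toSet == (1 to 4).map(Process(_)).toSet)
assert(events.map(_.last).toSet == Set(Close(None)))
```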
|
|
||
A separate review comment on this PR:

I agree that the old metric names don't make much sense anymore, but I worry about changing external-facing behavior as part of an API migration.