@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}

object DataSourceV2Strategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
@@ -32,6 +33,9 @@ object DataSourceV2Strategy {
case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil

case WriteToContinuousDataSource(writer, query) =>
WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil

case _ => Nil
}
}
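
With the new case in place, the planner maps the continuous logical node straight to its dedicated physical node. Below is a minimal sketch of exercising that mapping, assuming the Spark 2.3-era classes shown in this patch; the null writer and empty LocalRelation are placeholders only, since planning never inspects them.

import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}

// Build the logical node with placeholder inputs; the strategy only pattern-matches on it.
val plan = WriteToContinuousDataSource(writer = null, query = LocalRelation())

DataSourceV2Strategy(plan) match {
  case WriteToContinuousDataSourceExec(_, _) :: Nil => ()  // planned to the continuous exec node
  case other => sys.error(s"unexpected physical plan: $other")
}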
@@ -65,40 +65,21 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
s"The input RDD has ${messages.length} partitions.")

try {
val runTask = writer match {
// This case means that we're doing continuous processing. In microbatch streaming, the
// StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch.
case w: StreamWriter =>
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))

(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.runContinuous(writeTask, context, iter)
case _ =>
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator)
}

sparkContext.runJob(
rdd,
runTask,
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator),
rdd.partitions.indices,
(index, message: WriterCommitMessage) => {
messages(index) = message
writer.onDataWriterCommit(message)
}
)

if (!writer.isInstanceOf[StreamWriter]) {
logInfo(s"Data source writer $writer is committing.")
writer.commit(messages)
logInfo(s"Data source writer $writer committed.")
}
logInfo(s"Data source writer $writer is committing.")
writer.commit(messages)
logInfo(s"Data source writer $writer committed.")
} catch {
case _: InterruptedException if writer.isInstanceOf[StreamWriter] =>
// Interruption is how continuous queries are ended, so accept and ignore the exception.
case cause: Throwable =>
logError(s"Data source writer $writer is aborting.")
try {
@@ -111,8 +92,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
}
logError(s"Data source writer $writer aborted.")
cause match {
// Do not wrap interruption exceptions that will be handled by streaming specially.
case _ if StreamExecution.isInterruptionException(cause) => throw cause
// Only wrap non fatal exceptions.
case NonFatal(e) => throw new SparkException("Writing job aborted.", e)
case _ => throw cause
@@ -168,49 +147,6 @@ object DataWritingSparkTask extends Logging {
logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.")
})
}

def runContinuous(
writeTask: DataWriterFactory[InternalRow],
context: TaskContext,
iter: Iterator[InternalRow]): WriterCommitMessage = {
val epochCoordinator = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
val currentMsg: WriterCommitMessage = null
var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong

do {
var dataWriter: DataWriter[InternalRow] = null
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
try {
dataWriter = writeTask.createDataWriter(
context.partitionId(), context.attemptNumber(), currentEpoch)
while (iter.hasNext) {
dataWriter.write(iter.next())
}
logInfo(s"Writer for partition ${context.partitionId()} is committing.")
val msg = dataWriter.commit()
logInfo(s"Writer for partition ${context.partitionId()} committed.")
epochCoordinator.send(
CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
)
currentEpoch += 1
} catch {
case _: InterruptedException =>
// Continuous shutdown always involves an interrupt. Just finish the task.
}
})(catchBlock = {
// If there is an error, abort this writer. We enter this callback in the middle of
// rethrowing an exception, so runContinuous will stop executing at this point.
logError(s"Writer for partition ${context.partitionId()} is aborting.")
if (dataWriter != null) dataWriter.abort()
logError(s"Writer for partition ${context.partitionId()} aborted.")
})
} while (!context.isInterrupted())

currentMsg
}
}

class InternalRowDataWriterFactory(
@@ -199,7 +199,7 @@ class ContinuousExecution(
triggerLogicalPlan.schema,
outputMode,
new DataSourceOptions(extraOptions.asJava))
val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan)
val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan)

val reader = withSink.collect {
case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

/**
* The logical plan for writing data in a continuous stream.
*/
case class WriteToContinuousDataSource(
writer: StreamWriter, query: LogicalPlan) extends LogicalPlan {
override def children: Seq[LogicalPlan] = Seq(query)
override def output: Seq[Attribute] = Nil
}
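
For orientation: the only payload this node carries is a StreamWriter, whose driver-side commits are epoch-scoped rather than per-batch. A rough sketch of a trivial implementation, assuming the Spark 2.3-era streaming writer interface (PrintingStreamWriter is a made-up name, not part of this patch):

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

// Hypothetical sink: prints which epochs reach the driver-side commit. Only the
// epoch-scoped commit/abort plus the writer factory are implemented here; the
// batch-style commit(messages)/abort(messages) are left to StreamWriter's defaults.
class PrintingStreamWriter(factory: DataWriterFactory[Row]) extends StreamWriter {
  override def createWriterFactory(): DataWriterFactory[Row] = factory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"epoch $epochId committed with ${messages.length} partition messages")

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"epoch $epochId aborted")
}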
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import scala.util.control.NonFatal

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.util.Utils

/**
* The physical plan for writing data into a continuous processing [[StreamWriter]].
*/
case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan)
extends SparkPlan with Logging {
override def children: Seq[SparkPlan] = Seq(query)
override def output: Seq[Attribute] = Nil

override protected def doExecute(): RDD[InternalRow] = {
val writerFactory = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val rdd = query.execute()

logInfo(s"Start processing data source writer: $writer. " +
s"The input RDD has ${rdd.getNumPartitions} partitions.")
// Let the epoch coordinator know how many partitions the write RDD has.
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))

try {
// Force the RDD to run so continuous processing starts; no data is actually collected
// to the driver, since the write tasks return nothing and doExecute() returns an empty RDD.
sparkContext.runJob(
rdd,
(context: TaskContext, iter: Iterator[InternalRow]) =>
WriteToContinuousDataSourceExec.run(writerFactory, context, iter),
rdd.partitions.indices)
} catch {
case _: InterruptedException =>
// Interruption is how continuous queries are ended, so accept and ignore the exception.
case cause: Throwable =>
cause match {
// Do not wrap interruption exceptions that will be handled by streaming specially.
case _ if StreamExecution.isInterruptionException(cause) => throw cause
// Only wrap non fatal exceptions.
case NonFatal(e) => throw new SparkException("Writing job aborted.", e)
case _ => throw cause
}
}

sparkContext.emptyRDD
}
}

object WriteToContinuousDataSourceExec extends Logging {
def run(
writeTask: DataWriterFactory[InternalRow],
context: TaskContext,
iter: Iterator[InternalRow]): Unit = {
val epochCoordinator = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong

do {
var dataWriter: DataWriter[InternalRow] = null
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
try {
dataWriter = writeTask.createDataWriter(
context.partitionId(), context.attemptNumber(), currentEpoch)
while (iter.hasNext) {
dataWriter.write(iter.next())
}
logInfo(s"Writer for partition ${context.partitionId()} is committing.")
val msg = dataWriter.commit()
logInfo(s"Writer for partition ${context.partitionId()} committed.")
epochCoordinator.send(
CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
)
currentEpoch += 1
} catch {
case _: InterruptedException =>
// Continuous shutdown always involves an interrupt. Just finish the task.
}
})(catchBlock = {
// If there is an error, abort this writer. We enter this callback in the middle of
// rethrowing an exception, so this write task will stop executing at this point.
logError(s"Writer for partition ${context.partitionId()} is aborting.")
if (dataWriter != null) dataWriter.abort()
logError(s"Writer for partition ${context.partitionId()} aborted.")
})
} while (!context.isInterrupted())
}
}
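
To make the control flow of run easier to follow, here is a framework-free sketch of the same per-partition epoch loop. PartitionWriter, EpochSink, nextEpochRows and interrupted are hypothetical stand-ins for DataWriter, EpochCoordinatorRef, the continuous-read iterator and TaskContext.isInterrupted; the real task reuses a single input iterator, which is assumed to report hasNext = false at each epoch boundary and then resume for the next epoch.

trait PartitionWriter {
  def write(row: String): Unit
  def commit(): String          // returns this epoch's commit message
  def abort(): Unit
}

trait EpochSink {
  def commitPartition(partition: Int, epoch: Long, message: String): Unit
}

def epochLoop(
    partition: Int,
    startEpoch: Long,
    nextEpochRows: () => Iterator[String],   // one batch of rows per epoch
    newWriter: Long => PartitionWriter,      // a fresh writer per epoch
    sink: EpochSink,
    interrupted: () => Boolean): Unit = {
  var epoch = startEpoch
  do {
    val writer = newWriter(epoch)
    try {
      nextEpochRows().foreach(writer.write)                   // drain this epoch's data
      sink.commitPartition(partition, epoch, writer.commit()) // report the local commit
      epoch += 1                                              // then move on to the next epoch
    } catch {
      case _: InterruptedException => ()        // shutdown path: finish the loop quietly
      case e: Throwable => writer.abort(); throw e
    }
  } while (!interrupted())  // the task only ends when the query interrupts it
}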