@@ -170,6 +170,16 @@ object StaticSQLConf {
.intConf
.createWithDefault(1000)

+  val SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.shuffleExchange.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism for doing preparation of shuffle exchange, " +
+        "which includes subquery execution, file listing, etc.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
+      .createWithDefault(1024)

@yaooqinn (Member), Jul 30, 2024:

Can you explain why we pick this number? It might create memory pressure on the driver.

Contributor Author:

The shuffle async job just waits for other work (subquery expression execution) to finish, which is very lightweight. The broadcast async job executes a query and collects the result on the driver, which is very heavy. That's why we can give the shuffle async jobs much larger parallelism. In our benchmarks we found this number works reasonably well for TPC workloads.
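
A rough standalone sketch of the contrast (plain JDK executors standing in for Spark's internal ThreadUtils.newDaemonCachedThreadPool; the broadcast cap of 128 quotes the existing default of spark.sql.broadcastExchange.maxThreadThreshold):

```scala
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

// A cached pool in the style of newDaemonCachedThreadPool: threads are
// created on demand up to the cap and reaped after 60 seconds of idling.
def cachedPool(maxThreads: Int): ThreadPoolExecutor = {
  val pool = new ThreadPoolExecutor(
    maxThreads, maxThreads, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())
  pool.allowCoreThreadTimeOut(true)
  pool
}

// Shuffle preparation threads mostly block on subquery futures, so a high cap is cheap.
val shufflePool = cachedPool(1024)
// Broadcast threads run collect() and build the relation on the driver, so the cap stays low.
val broadcastPool = cachedPool(128)
```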

Member:

Is there a correlation with the number of system cores?

Contributor Author:

I don't think so; BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD is also much larger than the number of driver cores.

Member:

I'm not sure if this parameter has anything to do with SPARK-49091, or if that issue was just caused by SPARK-41914, which the reporter pointed to.

Also cc @wangyum

Member:

Update: SPARK-49091 is not related.
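
As context for tuning: both thresholds are static SQL confs, so they must be set before the session starts; changing them on a live session has no effect. A hypothetical example (512 is an illustrative value, not a recommendation):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("tune-shuffle-exchange-parallelism")
  // Static conf: read once when the shuffle-exchange thread pool is created.
  .config("spark.sql.shuffleExchange.maxThreadThreshold", "512")
  .getOrCreate()
```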


val BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD =
buildStaticConf("spark.sql.broadcastExchange.maxThreadThreshold")
.internal()

@@ -194,7 +194,7 @@ case class ShuffleQueryStageExec(

def advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize

- override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob
+ override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob()

override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {

@@ -240,7 +240,7 @@ case class BroadcastQueryStageExec(
throw SparkException.internalError(s"wrong plan for broadcast stage:\n ${plan.treeString}")
}

- override protected def doMaterialize(): Future[Any] = broadcast.submitBroadcastJob
+ override protected def doMaterialize(): Future[Any] = broadcast.submitBroadcastJob()

override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {

@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.MDC
- import org.apache.spark.rdd.RDD
+ import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics

@@ -61,23 +61,49 @@ trait BroadcastExchangeLike extends Exchange {
*/
def relationFuture: Future[broadcast.Broadcast[Any]]

+ @transient
+ private lazy val promise = Promise[Unit]()
+
+ @transient
+ private lazy val scalaFuture: scala.concurrent.Future[Unit] = promise.future
+
+ @transient
+ private lazy val triggerFuture: Future[Any] = {
+   SQLExecution.withThreadLocalCaptured(session, BroadcastExchangeExec.executionContext) {
+     try {
+       // Trigger broadcast preparation which can involve expensive operations like waiting on
+       // subqueries and file listing.
+       executeQuery(null)
+       promise.trySuccess(())
+     } catch {
+       case e: Throwable =>
+         promise.tryFailure(e)
+         throw e
+     }
+   }
+ }

+ protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]

/**
* The asynchronous job that materializes the broadcast. It's used for registering callbacks on
* `relationFuture`. Note that calling this method may not start the execution of broadcast job.
* It also does the preparations work, such as waiting for the subqueries.
*/
- final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
-   materializationStarted.set(true)
-   completionFuture
+ final def submitBroadcastJob(): scala.concurrent.Future[broadcast.Broadcast[Any]] = {
+   triggerFuture
+   scalaFuture.flatMap { _ =>
+     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
+       completionFuture
+     }
+   }(BroadcastExchangeExec.executionContext)
}

- protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]

/**
* Cancels broadcast job with an optional reason.
*/
final def cancelBroadcastJob(reason: Option[String]): Unit = {
-   if (isMaterializationStarted() && !this.relationFuture.isDone) {
+   if (!this.relationFuture.isDone) {
Contributor Author:

I did not re-implement broadcast cancellation, as that needs more refactoring to move the creation of the Future into BroadcastExchangeLike.

reason match {
case Some(r) => sparkContext.cancelJobsWithTag(this.jobTag, r)
case None => sparkContext.cancelJobsWithTag(this.jobTag)
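
The prepare-once shape used by triggerFuture and submitBroadcastJob() above can be sketched independently of Spark (plain Scala; all names here are illustrative):

```scala
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future, Promise}

object PrepareOnce {
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newCachedThreadPool())

  private val prepared = Promise[Unit]()

  // A lazy val, so the preparation body runs at most once, on first submit().
  private lazy val trigger: Future[Unit] = Future {
    try {
      // Stand-in for the expensive preparation (subquery execution, file listing).
      prepared.trySuccess(())
    } catch {
      case e: Throwable =>
        prepared.tryFailure(e)
        throw e
    }
    ()
  }

  // Every caller chains its job on the same single preparation run.
  def submit[T](job: => Future[T]): Future[T] = {
    trigger
    prepared.future.flatMap(_ => job)
  }
}
```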

@@ -17,8 +17,6 @@

package org.apache.spark.sql.execution.exchange

- import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow

@@ -36,17 +34,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
* "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe.
*/
abstract class Exchange extends UnaryExecNode {
- /**
-  * This flag aims to detect if the stage materialization is started. This helps
-  * to avoid unnecessary AQE stage materialization when the stage is canceled.
-  */
- protected val materializationStarted = new AtomicBoolean()
-
- /**
-  * Exposes status if the materialization is started
-  */
- def isMaterializationStarted(): Boolean = materializationStarted.get()

override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(EXCHANGE)


@@ -17,14 +17,15 @@

package org.apache.spark.sql.execution.exchange

+ import java.util.concurrent.atomic.AtomicReference
import java.util.function.Supplier

import scala.collection.mutable
- import scala.concurrent.Future
+ import scala.concurrent.{ExecutionContext, Future, Promise}

import org.apache.spark._
import org.apache.spark.internal.config
- import org.apache.spark.rdd.RDD
+ import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager

@@ -37,25 +38,15 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
- import org.apache.spark.sql.internal.SQLConf
- import org.apache.spark.util.MutablePair
+ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+ import org.apache.spark.util.{MutablePair, ThreadUtils}
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom

/**
* Common trait for all shuffle exchange implementations to facilitate pattern matching.
*/
trait ShuffleExchangeLike extends Exchange {

- /**
-  * The asynchronous job that materializes the shuffle. It also does the preparations work,
-  * such as waiting for the subqueries.
-  */
- @transient private lazy val shuffleFuture: Future[MapOutputStatistics] = executeQuery {
-   materializationStarted.set(true)
-   mapOutputStatisticsFuture
- }

/**
* Returns the number of mappers of this shuffle.
*/

@@ -76,26 +67,72 @@
*/
def shuffleOrigin: ShuffleOrigin

+ @transient
+ private lazy val promise = Promise[MapOutputStatistics]()
+
+ @transient
+ private lazy val completionFuture
+   : scala.concurrent.Future[MapOutputStatistics] = promise.future
+
+ @transient
+ private[sql] // Exposed for testing
+ val futureAction = new AtomicReference[Option[FutureAction[MapOutputStatistics]]](None)
+
+ @transient
+ private var isCancelled: Boolean = false
+
+ @transient
+ private lazy val triggerFuture: java.util.concurrent.Future[Any] = {
+   SQLExecution.withThreadLocalCaptured(session, ShuffleExchangeExec.executionContext) {
+     try {
+       // Trigger shuffle preparation which can involve expensive operations like waiting on
+       // subqueries and file listing.
+       executeQuery(null)
+       // Submit shuffle job if not cancelled.
+       this.synchronized {
+         if (isCancelled) {
+           promise.tryFailure(new SparkException("Shuffle cancelled."))
+         } else {
+           val shuffleJob = RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
+             mapOutputStatisticsFuture
+           }
+           shuffleJob match {
+             case action: FutureAction[MapOutputStatistics] => futureAction.set(Some(action))
+             case _ =>
+           }
+           promise.completeWith(shuffleJob)
+         }
+       }
+       null
+     } catch {
+       case e: Throwable =>
+         promise.tryFailure(e)
+         throw e
+     }
+   }
+ }

  /**
-  * Submits the shuffle job.
+  * The asynchronous job that materializes the shuffle. It also does the preparations work,
+  * such as waiting for the subqueries.
   */
- final def submitShuffleJob: Future[MapOutputStatistics] = shuffleFuture
-
- protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+ final def submitShuffleJob(): Future[MapOutputStatistics] = {
+   triggerFuture
+   completionFuture
+ }

/**
* Cancels the shuffle job with an optional reason.
*/
- final def cancelShuffleJob(reason: Option[String]): Unit = {
-   if (isMaterializationStarted()) {
-     shuffleFuture match {
-       case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
-         action.cancel(reason)
-       case _ =>
-     }
+ final def cancelShuffleJob(reason: Option[String]): Unit = this.synchronized {
+   if (!isCancelled) {
+     isCancelled = true
+     futureAction.get().foreach(_.cancel(reason))
   }
}

+ protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]

/**
* Returns the shuffle RDD with specified partition specs.
*/
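
The synchronized block in cancelShuffleJob pairs with the one inside triggerFuture: both paths take the same monitor, so a cancel either lands before submission (and the shuffle job is never started) or finds the FutureAction to cancel. A standalone sketch of that pattern (illustrative names, not Spark API):

```scala
import java.util.concurrent.atomic.AtomicReference

final class OnceCancellableJob {
  private var isCancelled = false
  // Holds the job's cancel handle once the job has actually been started.
  private val handle = new AtomicReference[Option[() => Unit]](None)

  /** Called from the preparation thread; returns false if cancel() won the race. */
  def startIfNotCancelled(start: () => () => Unit): Boolean = this.synchronized {
    if (isCancelled) {
      false
    } else {
      handle.set(Some(start())) // start() launches the job and returns its cancel handle
      true
    }
  }

  /** Idempotent: either prevents the start entirely or cancels the running job. */
  def cancel(): Unit = this.synchronized {
    if (!isCancelled) {
      isCancelled = true
      handle.get().foreach(_.apply())
    }
  }
}
```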

@@ -231,6 +268,10 @@ case class ShuffleExchangeExec(

object ShuffleExchangeExec {

+ private[execution] val executionContext = ExecutionContext.fromExecutorService(
+   ThreadUtils.newDaemonCachedThreadPool("shuffle-exchange",
+     SQLConf.get.getConf(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD)))

/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The

@@ -1006,7 +1006,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
delegate.shuffleOrigin
}
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
-   delegate.submitShuffleJob
+   delegate.submitShuffleJob()
override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
delegate.getShuffleRDD(partitionSpecs)
override def runtimeStatistics: Statistics = {

@@ -1032,7 +1032,7 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
override val runId: UUID = delegate.runId
override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
delegate.relationFuture
- override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob
+ override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob()
override def runtimeStatistics: Statistics = delegate.runtimeStatistics
override def child: SparkPlan = delegate.child
override protected def doPrepare(): Unit = delegate.prepare()