@@ -35,8 +35,9 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.ExecuteHolder
import org.apache.spark.sql.connect.utils.MetricGenerator
import org.apache.spark.sql.execution.{LocalTableScanExec, SQLExecution}
import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

@@ -58,11 +59,21 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
}
val planner = new SparkConnectPlanner(executeHolder)
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
val conf = session.sessionState.conf
val shuffleCleanupMode =
if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED)) {
RemoveShuffleFiles
} else if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED)) {
SkipMigration
} else {
DoNotCleanup
}
val dataframe =
Dataset.ofRows(
sessionHolder.session,
planner.transformRelation(request.getPlan.getRoot),
tracker)
tracker,
shuffleCleanupMode)
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
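For reference, the selection above gives file cleanup precedence over skip-migration when both flags are enabled. A REPL-style sketch of the equivalent standalone logic (the helper name is illustrative, not part of this change):

import org.apache.spark.sql.execution.{DoNotCleanup, RemoveShuffleFiles, ShuffleCleanupMode, SkipMigration}
import org.apache.spark.sql.internal.SQLConf

// Mirrors the if/else chain above: file cleanup wins, then skip-migration,
// otherwise no cleanup is performed.
def resolveShuffleCleanupMode(conf: SQLConf): ShuffleCleanupMode = {
  if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED)) {
    RemoveShuffleFiles
  } else if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED)) {
    SkipMigration
  } else {
    DoNotCleanup
  }
}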
@@ -24,6 +24,8 @@ import java.nio.file.Files

import scala.collection.mutable.ArrayBuffer

import com.google.common.cache.CacheBuilder

import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.{config, Logging, MDC}
@@ -76,13 +78,21 @@ private[spark] class IndexShuffleBlockResolver(
override def getStoredShuffles(): Seq[ShuffleBlockInfo] = {
val allBlocks = blockManager.diskBlockManager.getAllBlocks()
allBlocks.flatMap {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
case ShuffleIndexBlockId(shuffleId, mapId, _)
if Option(shuffleIdsToSkip.getIfPresent(shuffleId)).isEmpty =>
Some(ShuffleBlockInfo(shuffleId, mapId))
case _ =>
None
}
}

private val shuffleIdsToSkip =
CacheBuilder.newBuilder().maximumSize(1000).build[java.lang.Integer, java.lang.Boolean]()
Contributor:
if the value does not matter, shall we just use Object type and always pass null?

Contributor Author:
Unfortunately Guava cache won't accept null values...


override def addShuffleToSkip(shuffleId: ShuffleId): Unit = {
shuffleIdsToSkip.put(shuffleId, true)
}
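As the thread above notes, Guava caches reject null values, so the java.lang.Boolean value is only a placeholder and key presence is what matters. A REPL-style sketch of the same pattern (names are illustrative):

import com.google.common.cache.CacheBuilder

// Bounded "skip set" keyed by shuffle ID; at most 1000 entries are retained.
val skipCache = CacheBuilder.newBuilder()
  .maximumSize(1000)
  .build[java.lang.Integer, java.lang.Boolean]()

skipCache.put(42, true)                              // mark shuffle 42 as "do not migrate"
assert(Option(skipCache.getIfPresent(42)).isDefined) // present => skipped
assert(Option(skipCache.getIfPresent(7)).isEmpty)    // absent keys come back as null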

private def getShuffleBytesStored(): Long = {
val shuffleFiles: Seq[File] = getStoredShuffles().map {
si => getDataFile(si.shuffleId, si.mapId)
@@ -35,6 +35,11 @@ trait MigratableResolver {
*/
def getStoredShuffles(): Seq[ShuffleBlockInfo]

/**
* Mark a shuffle that should not be migrated.
*/
def addShuffleToSkip(shuffleId: Int): Unit = {}

/**
* Write a provided shuffle block as a stream. Used for block migrations.
* Up to the implementation to support STORAGE_REMOTE_SHUFFLE_MAX_DISK
@@ -305,7 +305,7 @@ private[spark] class BlockManager(

// This is a lazy val so someone can migrate RDDs even if they don't have a MigratableResolver
// for shuffles. Used in BlockManagerDecommissioner & block puts.
private[storage] lazy val migratableResolver: MigratableResolver = {
lazy val migratableResolver: MigratableResolver = {
shuffleManager.shuffleBlockResolver.asInstanceOf[MigratableResolver]
}
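Widening the visibility here (previously private[storage]) lets driver-side code outside the storage package reach the resolver; the new caller added in SQLExecution.scala later in this diff is:

SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)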

7 changes: 5 additions & 2 deletions project/MimaExcludes.scala
@@ -15,7 +15,8 @@
* limitations under the License.
*/

import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.core
import com.typesafe.tools.mima.core.*

/**
* Additional excludes for checking of Spark's binary compatibility.
@@ -93,7 +94,9 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$")
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$"),
// SPARK-47764: Cleanup shuffle dependencies based on ShuffleCleanupMode
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.MigratableResolver.addShuffleToSkip")
)

// Default exclude rules
@@ -2874,6 +2874,22 @@
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED =
buildConf("spark.sql.shuffleDependency.skipMigration.enabled")
.doc("When enabled, shuffle dependencies for a Spark Connect SQL execution are marked at " +
"the end of the execution, and they will not be migrated during decommissions.")
.version("4.0.0")
.booleanConf
.createWithDefault(Utils.isTesting)

val SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED =
buildConf("spark.sql.shuffleDependency.fileCleanup.enabled")
.doc("When enabled, shuffle files will be cleaned up at the end of Spark Connect " +
"SQL executions.")
.version("4.0.0")
.booleanConf
.createWithDefault(Utils.isTesting)

val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
.internal()
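A usage sketch for the two configs introduced above (both default to Utils.isTesting, so they are effectively off outside tests). Per the selection logic in SparkConnectPlanExecution earlier in this diff, fileCleanup takes precedence when both are enabled; the session name here is assumed:

// Only affects Spark Connect SQL executions, per the docs above.
spark.conf.set("spark.sql.shuffleDependency.fileCleanup.enabled", "true")
// Or, to keep shuffle files but exclude them from decommission migration:
spark.conf.set("spark.sql.shuffleDependency.skipMigration.enabled", "true")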
20 changes: 18 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -95,10 +95,26 @@ private[sql] object Dataset {
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}

def ofRows(
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
shuffleCleanupMode: ShuffleCleanupMode): DataFrame =
sparkSession.withActive {
val qe = new QueryExecution(
sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}

/** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
def ofRows(
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
tracker: QueryPlanningTracker,
shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup)
: DataFrame = sparkSession.withActive {
val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
val qe = new QueryExecution(
sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}
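A usage sketch of the new overload, mirroring the tests added to QueryExecutionSuite at the end of this diff. Dataset.ofRows sits on the private[sql] companion, so this only compiles from code inside Spark's sql package tree (as the test suite does); `spark` is an assumed existing SparkSession:

import org.apache.spark.sql.execution.{RemoveShuffleFiles, SkipMigration}

val plan = spark.range(100).repartition(10).logicalPlan

// Shuffle files are removed once the execution finishes.
Dataset.ofRows(spark, plan, RemoveShuffleFiles).collect()

// Shuffle files are kept, but will not be migrated at node decommission.
Dataset.ofRows(spark, plan, SkipMigration).collect()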
@@ -60,7 +60,8 @@ class QueryExecution(
val sparkSession: SparkSession,
val logical: LogicalPlan,
val tracker: QueryPlanningTracker = new QueryPlanningTracker,
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging {
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup) extends Logging {

val id: Long = QueryExecution.nextExecutionId

@@ -459,6 +460,22 @@ object CommandExecutionMode extends Enumeration {
val SKIP, NON_ROOT, ALL = Value
}

/**
* Modes for shuffle dependency cleanup.
*
* DoNotCleanup: Do not perform any cleanup.
* SkipMigration: Shuffle dependencies will not be migrated at node decommissions.
* RemoveShuffleFiles: Shuffle dependency files are removed at the end of SQL executions.
*/
sealed trait ShuffleCleanupMode

case object DoNotCleanup extends ShuffleCleanupMode

case object SkipMigration extends ShuffleCleanupMode

case object RemoveShuffleFiles extends ShuffleCleanupMode
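A construction sketch showing where the new parameter sits on QueryExecution (internal API; tracker and mode keep their defaults). The SparkSession `spark` and LogicalPlan `plan` are assumed, for example spark.range(10).repartition(2).logicalPlan from inside the sql package:

val qe = new QueryExecution(spark, plan, shuffleCleanupMode = RemoveShuffleFiles)
assert(qe.shuffleCleanupMode == RemoveShuffleFiles)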


object QueryExecution {
private val _nextExecutionId = new AtomicLong(0)

@@ -20,14 +20,16 @@ package org.apache.spark.sql.execution
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkException, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, SPARK_JOB_INTERRUPT_ON_CANCEL}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX}
import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
@@ -115,6 +117,7 @@ object SQLExecution extends Logging {

withSQLConfPropagated(sparkSession) {
var ex: Option[Throwable] = None
var isExecutedPlanAvailable = false
val startTime = System.nanoTime()
val startEvent = SparkListenerSQLExecutionStart(
executionId = executionId,
@@ -147,6 +150,7 @@
}
sc.listenerBus.post(
startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
isExecutedPlanAvailable = true
f()
}
} catch {
@@ -161,6 +165,24 @@
case e =>
Utils.exceptionString(e)
}
if (queryExecution.shuffleCleanupMode != DoNotCleanup
&& isExecutedPlanAvailable) {
val shuffleIds = queryExecution.executedPlan match {
Contributor:
It seems the root node can be a command. Shall we collect all the AdaptiveSparkPlanExec inside the plan ?

Contributor:
Oh this is a good catch! I think we should. cc @bozhang2820

Contributor Author:
I could be wrong but I thought DataFrames for commands are created in SparkConnectPlanner, and the ones for queries are only created in SparkConnectPlanExecution?

Contributor (@cloud-fan, Apr 30, 2024):
Ideally we should clean up shuffles for CTAS and INSERT as well, as they also run queries.

case ae: AdaptiveSparkPlanExec =>
ae.context.shuffleIds.asScala.keys
case _ =>
Iterable.empty
}
shuffleIds.foreach { shuffleId =>
queryExecution.shuffleCleanupMode match {
case RemoveShuffleFiles =>
SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
Contributor:
Shall we call shuffleDriverComponents.removeShuffle ? We are at driver side, shuffleManager.unregisterShuffle would do nothing in non-local mode.

Contributor Author:
Thanks for catching this! Will fix this in a follow-up asap.

Contributor Author:
Created #46302.

case SkipMigration =>
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
case _ => // this should not happen
}
}
}
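As the thread above points out, the driver-side shuffleManager.unregisterShuffle call does nothing in non-local mode, and the follow-up (#46302) moves this to the driver-side removal path. A rough sketch of that direction only; the accessor below is an assumption, not the follow-up's exact code:

// Ask the shuffle driver components to remove the shuffle cluster-wide instead
// of only unregistering it in the local ShuffleManager. `sc` and `shuffleId`
// are assumed; shuffleDriverComponents is internal to SparkContext.
sc.shuffleDriverComponents.removeShuffle(shuffleId, /* blocking = */ true)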
val event = SparkListenerSQLExecutionEnd(
executionId,
System.currentTimeMillis(),
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive

import java.util
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.concurrent.TrieMap
import scala.collection.mutable
@@ -302,6 +302,11 @@ case class AdaptiveSparkPlanExec(
try {
stage.materialize().onComplete { res =>
if (res.isSuccess) {
// record shuffle IDs for successful stages for cleanup
stage.plan.collect {
case s: ShuffleExchangeLike =>
context.shuffleIds.put(s.shuffleId, true)
}
events.offer(StageSuccess(stage, res.get))
} else {
events.offer(StageFailure(stage, res.failed.get))
@@ -869,6 +874,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
*/
val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
new TrieMap[SparkPlan, ExchangeQueryStageExec]()

val shuffleIds: ConcurrentHashMap[Int, Boolean] = new ConcurrentHashMap[Int, Boolean]()
}
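The new shuffleIds map acts as a concurrent set (the Boolean value is a placeholder): stages record their shuffle ID when they materialize successfully, and SQLExecution drains the keys after the query finishes. A REPL-style sketch of that pattern in isolation:

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

val shuffleIds = new ConcurrentHashMap[Int, Boolean]()
shuffleIds.put(3, true)              // recorded when a shuffle stage materializes
shuffleIds.put(5, true)
val ids = shuffleIds.asScala.keys    // consumed later; yields 3 and 5, order not guaranteed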

/**
@@ -86,6 +86,11 @@ trait ShuffleExchangeLike extends Exchange {
* Returns the runtime statistics after shuffle materialization.
*/
def runtimeStatistics: Statistics

/**
* The shuffle ID.
*/
def shuffleId: Int
}
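Because this adds an abstract member, any ShuffleExchangeLike implementation now has to provide it; the delegating test implementation later in this diff is the simplest form:

// From MyShuffleExchangeExec in the test changes below:
override def shuffleId: Int = delegate.shuffleId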

// Describes where the shuffle operator comes from.
@@ -166,6 +171,8 @@ case class ShuffleExchangeExec(
Statistics(dataSize, Some(rowCount))
}

override def shuffleId: Int = shuffleDependency.shuffleId

/**
* A [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
@@ -1014,6 +1014,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
val attributeStats = AttributeMap(Seq((child.output.head, columnStats)))
Statistics(stats.sizeInBytes, stats.rowCount, attributeStats)
}
override def shuffleId: Int = delegate.shuffleId
override def child: SparkPlan = delegate.child
override protected def doExecute(): RDD[InternalRow] = delegate.execute()
override def outputPartitioning: Partitioning = delegate.outputPartitioning
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.ShuffleIndexBlockId
import org.apache.spark.util.Utils

case class QueryExecutionTestRecord(
@@ -314,6 +315,48 @@ class QueryExecutionSuite extends SharedSparkSession {
mockCallback.assertExecutedPlanPrepared()
}

private def cleanupShuffles(): Unit = {
val blockManager = spark.sparkContext.env.blockManager
blockManager.diskBlockManager.getAllBlocks().foreach {
case ShuffleIndexBlockId(shuffleId, _, _) =>
spark.sparkContext.env.shuffleManager.unregisterShuffle(shuffleId)
case _ =>
}
}

test("SPARK-47764: Cleanup shuffle dependencies - DoNotCleanup mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, DoNotCleanup)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().nonEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
cleanupShuffles()
}

test("SPARK-47764: Cleanup shuffle dependencies - SkipMigration mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, SkipMigration)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
cleanupShuffles()
}

test("SPARK-47764: Cleanup shuffle dependencies - RemoveShuffleFiles mode") {
val plan = spark.range(100).repartition(10).logicalPlan
val df = Dataset.ofRows(spark, plan, RemoveShuffleFiles)
df.collect()

val blockManager = spark.sparkContext.env.blockManager
assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
cleanupShuffles()
}

test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") {
val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan
assert(plan.isInstanceOf[CommandResultExec])