From e93cedd490725ea59f826e202d33dbbceceb8eba Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 20 Aug 2024 14:05:55 +0200 Subject: [PATCH 01/28] tags api --- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../CheckConnectJvmClientCompatibility.scala | 15 --- .../scala/org/apache/spark/SparkContext.scala | 22 ++++ .../apache/spark/scheduler/DAGScheduler.scala | 55 +++++++--- .../spark/scheduler/DAGSchedulerEvent.scala | 6 +- .../org/apache/spark/sql/SparkSession.scala | 103 ++++++++++++++++++ 6 files changed, 168 insertions(+), 35 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2e54617928aa..2d93ade2779c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -741,7 +741,7 @@ class SparkSession private[sql] ( * Often, a unit of execution in an application consists of multiple Spark executions. * Application programmers can use this method to group all those jobs together and give a group * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running running executions with this tag. For example: + * running executions with this tag. For example: * {{{ * // In the main thread: * spark.addTag("myjobs") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 07c9e5190da0..f7cb4dd5ee35 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -402,21 +402,6 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SparkSession.addArtifacts"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.registerClassFinder"), - // public - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptAll"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptOperation"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.removeTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.getTags"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 485f0abcd25e..6093b2ada9fa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2684,6 +2684,22 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None) } + 
/** + * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation + * @param jobIdCallback callback function to be called with the job ID of each job that is being + * cancelled. + * + * @since 4.0.0 + */ + def cancelJobsWithTag(tag: String, reason: String, jobIdCallback: Int => Unit): Unit = { + SparkContext.throwIfInvalidTag(tag) + assertNotStopped() + dagScheduler.cancelJobsWithTag(tag, Option(reason)) + } + /** * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. * @@ -2717,6 +2733,12 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelAllJobs() } + /** Cancel all jobs that have been scheduled or are running. */ + def cancelAllJobs(jobIdCallback: Int => Unit): Unit = { + assertNotStopped() + dagScheduler.cancelAllJobs(Some(jobIdCallback)) + } + /** * Cancel a given job if it's scheduled or running. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6c824e2fdeae..f194ce3f318b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1117,23 +1117,37 @@ private[spark] class DAGScheduler( /** * Cancel all jobs with a given tag. */ - def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = { + def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = + cancelJobsWithTag(tag, reason, jobIdCallback = None) + + /** + * Cancel all jobs with a given tag. + */ + def cancelJobsWithTag( + tag: String, + reason: Option[String], + jobIdCallback: Option[Int => Unit]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason)) + eventProcessLoop.post(JobTagCancelled(tag, reason, jobIdCallback)) } /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs(): Unit = { - eventProcessLoop.post(AllJobsCancelled) + def cancelAllJobs(): Unit = cancelAllJobs(jobIdCallback = None) + + /** + * Cancel all jobs that are running or waiting in the queue. + */ + def cancelAllJobs(jobIdCallback: Option[Int => Unit]): Unit = { + eventProcessLoop.post(AllJobsCancelled(jobIdCallback)) } - private[scheduler] def doCancelAllJobs(): Unit = { + private[spark] def doCancelAllJobs(jobIdCallback: Option[Int => Unit]): Unit = { // Cancel all running jobs. runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, - Option("as part of cancellation of all jobs"))) + Option("as part of cancellation of all jobs"), jobIdCallback)) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... 
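As a rough caller-side sketch of the callback-taking overloads added to SparkContext above (illustrative only, not taken from the patch; the tag name, reason string, and result set are placeholders, and `sc` is assumed to be a live SparkContext whose jobs were tagged with sc.addJobTag):

    import scala.collection.mutable

    // Jobs of interest were tagged earlier on the submitting thread, e.g. sc.addJobTag("etl").
    val cancelledJobIds = mutable.Set[Int]()

    // Ask the scheduler to cancel every active job carrying the tag, recording each cancelled ID.
    sc.cancelJobsWithTag("etl", "cancelled by operator", (jobId: Int) => cancelledJobIds += jobId)

    // Or cancel everything that is scheduled or running, collecting IDs the same way.
    sc.cancelAllJobs((jobId: Int) => cancelledJobIds += jobId)

    // Note: cancellation is posted to the DAGScheduler event loop, so the callback fires
    // asynchronously; the set may still be filling in when these calls return.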
} @@ -1231,10 +1245,13 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallback = None)) } - private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = { + private[scheduler] def handleJobTagCancelled( + tag: String, + reason: Option[String], + jobIdCallbackOpt: Option[Int => Unit]): Unit = { // Cancel all jobs belonging that have this tag. // First finds all active jobs with this group id, and then kill stages for them. val jobIds = activeJobs.filter { activeJob => @@ -1244,7 +1261,7 @@ private[spark] class DAGScheduler( } }.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallbackOpt)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2801,14 +2818,17 @@ private[spark] class DAGScheduler( case None => s"because Stage $stageId was cancelled" } - handleJobCancellation(jobId, Option(reasonStr)) + handleJobCancellation(jobId, Option(reasonStr), jobIdCallback = None) } case None => logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") } } - private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = { + private[scheduler] def handleJobCancellation( + jobId: Int, + reason: Option[String], + jobIdCallback: Option[Int => Unit]): Unit = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { @@ -2816,6 +2836,7 @@ private[spark] class DAGScheduler( job = jobIdToActiveJob(jobId), error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) ) + jobIdCallback.foreach(_(jobId)) } } @@ -3108,16 +3129,16 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.handleStageCancellation(stageId, reason) case JobCancelled(jobId, reason) => - dagScheduler.handleJobCancellation(jobId, reason) + dagScheduler.handleJobCancellation(jobId, reason, jobIdCallback = None) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason) => - dagScheduler.handleJobTagCancelled(tag, reason) + case JobTagCancelled(tag, reason, callback) => + dagScheduler.handleJobTagCancelled(tag, reason, callback) - case AllJobsCancelled => - dagScheduler.doCancelAllJobs() + case AllJobsCancelled(callback) => + dagScheduler.doCancelAllJobs(callback) case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) @@ -3173,7 +3194,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler override def onError(e: Throwable): Unit = { logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) try { - dagScheduler.doCancelAllJobs() + dagScheduler.doCancelAllJobs(jobIdCallback = None) } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index c9ad54d1fdc7..2cd6994b7fed 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -71,9 +71,11 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, - reason: Option[String]) extends DAGSchedulerEvent + reason: Option[String], + jobIdCallback: Option[Int => Unit]) extends DAGSchedulerEvent -private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent +private[scheduler] case class AllJobsCancelled( + jobIdCallback: Option[Int => Unit]) extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d64623a744fe..8be10026f473 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -22,6 +22,7 @@ import java.util.{ServiceLoader, UUID} import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -793,6 +794,108 @@ class SparkSession private( } } + + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * Often, a unit of execution in an application consists of multiple Spark executions. + * Application programmers can use this method to group all those jobs together and give a group + * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all + * running executions with this tag. For example: + * {{{ + * // In the main thread: + * spark.addTag("myjobs") + * spark.range(10).map(i => { Thread.sleep(10); i }).collect() + * + * // In a separate thread: + * spark.interruptTag("myjobs") + * }}} + * + * There may be multiple tags present at the same time, so different parts of application may + * use different tags to perform cancellation at different levels of granularity. + * + * @param tag + * The tag to be added. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def addTag(tag: String): Unit = sparkContext.addJobTag(tag) + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def removeTag(tag: String): Unit = sparkContext.removeJobTag(tag) + + /** + * Get the tags that are currently set to be assigned to all the operations started by this + * thread. + * + * @since 4.0.0 + */ + def getTags(): Set[String] = sparkContext.getJobTags() + + /** + * Clear the current thread's operation tags. + * + * @since 4.0.0 + */ + def clearTags(): Unit = sparkContext.clearJobTags() + + /** + * Interrupt all operations of this session that are currently running. + * + * @return + * sequence of Job IDs of interrupted operations. 
+ * + * @since 4.0.0 + */ + def interruptAll(): Seq[String] = { + val jobIds = mutable.Set[Int]() + sparkContext.cancelAllJobs(jobIdCallback = (jobId: Int) => jobIds.add(jobId)) + jobIds.toSeq.map(_.toString) + } + + /** + * Interrupt all operations of this session with the given operation tag. + * + * @return + * sequence of Job IDs of interrupted operations. + * + * @since 4.0.0 + */ + def interruptTag(tag: String): Seq[String] = { + val jobIds = mutable.Set[Int]() + sparkContext.cancelJobsWithTag( + tag, + "Interrupted by user", + jobIdCallback = (jobId: Int) => jobIds.add(jobId)) + jobIds.toSeq.map(_.toString) + } + + /** + * Interrupt an operation of this session with the given Job ID. + * + * @return + * sequence of Job IDs of interrupted operations. + * + * @since 4.0.0 + */ + def interruptOperation(jobId: String): Seq[String] = { + scala.util.Try(jobId.toInt).toOption match { + case Some(id) => + sparkContext.cancelJob(id, "Interrupted by user") + Seq(jobId) + case None => + throw new IllegalArgumentException("jobId must be a number.") + } + } + /** * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a * `DataFrame`. From 40610a7c06585c52ddbea27c160dafeba8706d77 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 20 Aug 2024 14:27:22 +0200 Subject: [PATCH 02/28] rename --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f194ce3f318b..cd72972bec9c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1251,7 +1251,7 @@ private[spark] class DAGScheduler( private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], - jobIdCallbackOpt: Option[Int => Unit]): Unit = { + jobIdCallback: Option[Int => Unit]): Unit = { // Cancel all jobs belonging that have this tag. // First finds all active jobs with this group id, and then kill stages for them. val jobIds = activeJobs.filter { activeJob => @@ -1261,7 +1261,7 @@ private[spark] class DAGScheduler( } }.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallbackOpt)) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallback)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { From a70d7d28b7788ecf55c21699040ec337b70213cd Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 21 Aug 2024 16:25:33 +0200 Subject: [PATCH 03/28] address comments --- .../scala/org/apache/spark/SparkContext.scala | 17 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 52 ++++++------ .../spark/scheduler/DAGSchedulerEvent.scala | 5 +- .../org/apache/spark/sql/SparkSession.scala | 84 +++++++++++++++---- 4 files changed, 101 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6093b2ada9fa..dd445dd303c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2689,15 +2689,18 @@ class SparkContext(config: SparkConf) extends Logging { * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. 
* @param reason reason for cancellation - * @param jobIdCallback callback function to be called with the job ID of each job that is being - * cancelled. + * @param shouldCancelJob callback function to be called with the job ID of each job that matches + * the given tag. If the function returns true, the job will be cancelled. * * @since 4.0.0 */ - def cancelJobsWithTag(tag: String, reason: String, jobIdCallback: Int => Unit): Unit = { + def cancelJobsWithTag( + tag: String, + reason: String, + shouldCancelJob: Int => Boolean): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason)) + dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(shouldCancelJob)) } /** @@ -2733,12 +2736,6 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelAllJobs() } - /** Cancel all jobs that have been scheduled or are running. */ - def cancelAllJobs(jobIdCallback: Int => Unit): Unit = { - assertNotStopped() - dagScheduler.cancelAllJobs(Some(jobIdCallback)) - } - /** * Cancel a given job if it's scheduled or running. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cd72972bec9c..17c244d760d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1118,7 +1118,7 @@ private[spark] class DAGScheduler( * Cancel all jobs with a given tag. */ def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = - cancelJobsWithTag(tag, reason, jobIdCallback = None) + cancelJobsWithTag(tag, reason, shouldCancelJob = None) /** * Cancel all jobs with a given tag. @@ -1126,28 +1126,25 @@ private[spark] class DAGScheduler( def cancelJobsWithTag( tag: String, reason: Option[String], - jobIdCallback: Option[Int => Unit]): Unit = { + shouldCancelJob: Option[Int => Boolean]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason, jobIdCallback)) + eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob)) } /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs(): Unit = cancelAllJobs(jobIdCallback = None) - - /** - * Cancel all jobs that are running or waiting in the queue. - */ - def cancelAllJobs(jobIdCallback: Option[Int => Unit]): Unit = { - eventProcessLoop.post(AllJobsCancelled(jobIdCallback)) + def cancelAllJobs(): Unit = { + eventProcessLoop.post(AllJobsCancelled) } - private[spark] def doCancelAllJobs(jobIdCallback: Option[Int => Unit]): Unit = { + private[scheduler] def doCancelAllJobs(): Unit = { // Cancel all running jobs. - runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, - Option("as part of cancellation of all jobs"), jobIdCallback)) + runningStages.map(_.firstJobId).foreach(handleJobCancellation( + _, + Option("as part of cancellation of all jobs"), + shouldCancelJob = None)) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... 
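The Boolean-returning shape above makes the callback double as a filter: returning false leaves that job running. A short caller-side sketch under the same assumptions as the earlier one (placeholder tag, reason, and policy; not taken from the patch):

    import scala.collection.mutable

    val cancelledJobIds = mutable.Set[Int]()

    // Only jobs for which the predicate returns true are actually cancelled.
    sc.cancelJobsWithTag("etl", "cancelled by operator", (jobId: Int) => {
      val cancel = jobId % 2 == 0 // placeholder policy: keep odd-numbered jobs alive
      if (cancel) cancelledJobIds += jobId
      cancel
    })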
} @@ -1245,13 +1242,13 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallback = None)) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason), shouldCancelJob = None)) } private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], - jobIdCallback: Option[Int => Unit]): Unit = { + shouldCancelJob: Option[Int => Boolean]): Unit = { // Cancel all jobs belonging that have this tag. // First finds all active jobs with this group id, and then kill stages for them. val jobIds = activeJobs.filter { activeJob => @@ -1261,7 +1258,7 @@ private[spark] class DAGScheduler( } }.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason), jobIdCallback)) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason), shouldCancelJob)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2818,7 +2815,7 @@ private[spark] class DAGScheduler( case None => s"because Stage $stageId was cancelled" } - handleJobCancellation(jobId, Option(reasonStr), jobIdCallback = None) + handleJobCancellation(jobId, Option(reasonStr), shouldCancelJob = None) } case None => logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") @@ -2828,15 +2825,16 @@ private[spark] class DAGScheduler( private[scheduler] def handleJobCancellation( jobId: Int, reason: Option[String], - jobIdCallback: Option[Int => Unit]): Unit = { + shouldCancelJob: Option[Int => Boolean]): Unit = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - failJobAndIndependentStages( - job = jobIdToActiveJob(jobId), - error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) - ) - jobIdCallback.foreach(_(jobId)) + if (shouldCancelJob.forall(_(jobId))) { + failJobAndIndependentStages( + job = jobIdToActiveJob(jobId), + error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) + ) + } } } @@ -3129,7 +3127,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.handleStageCancellation(stageId, reason) case JobCancelled(jobId, reason) => - dagScheduler.handleJobCancellation(jobId, reason, jobIdCallback = None) + dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob = None) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) @@ -3137,8 +3135,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobTagCancelled(tag, reason, callback) => dagScheduler.handleJobTagCancelled(tag, reason, callback) - case AllJobsCancelled(callback) => - dagScheduler.doCancelAllJobs(callback) + case AllJobsCancelled => + dagScheduler.doCancelAllJobs() case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) @@ -3194,7 +3192,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler override def onError(e: Throwable): Unit = { logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) try { - dagScheduler.doCancelAllJobs(jobIdCallback = None) + dagScheduler.doCancelAllJobs() } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2cd6994b7fed..449729704ae4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -72,10 +72,9 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, reason: Option[String], - jobIdCallback: Option[Int => Unit]) extends DAGSchedulerEvent + shouldCancelJob: Option[Int => Boolean]) extends DAGSchedulerEvent -private[scheduler] case class AllJobsCancelled( - jobIdCallback: Option[Int => Unit]) extends DAGSchedulerEvent +private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e88c8031ec80..4908b825ddcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.util.{ServiceLoader, UUID} +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} @@ -34,7 +35,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME} import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.scheduler.{ActiveJob, SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ @@ -123,6 +124,15 @@ class SparkSession private( .getOrElse(SQLConf.getFallbackConf) }) + private lazy val sessionJobTag = s"spark-session-$sessionUUID" + + /** + * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. + * Real tag have the current session ID attached: + * "tag1" -> s"spark-514b7fac-51bf-458e-8bd8-0340d646326c-tag1" + */ + private val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = new ConcurrentHashMap() + /** * The version of Spark on which this application is running. 
* @@ -820,7 +830,10 @@ class SparkSession private( * * @since 4.0.0 */ - def addTag(tag: String): Unit = sparkContext.addJobTag(tag) + def addTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + userDefinedToRealTagsMap.put(tag, s"spark-$sessionUUID-$tag") + } /** * Remove a tag previously added to be assigned to all the operations started by this thread in @@ -831,7 +844,10 @@ class SparkSession private( * * @since 4.0.0 */ - def removeTag(tag: String): Unit = sparkContext.removeJobTag(tag) + def removeTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + userDefinedToRealTagsMap.remove(tag) + } /** * Get the tags that are currently set to be assigned to all the operations started by this @@ -839,14 +855,14 @@ class SparkSession private( * * @since 4.0.0 */ - def getTags(): Set[String] = sparkContext.getJobTags() + def getTags(): Set[String] = userDefinedToRealTagsMap.keys().asScala.toSet /** * Clear the current thread's operation tags. * * @since 4.0.0 */ - def clearTags(): Unit = sparkContext.clearJobTags() + def clearTags(): Unit = userDefinedToRealTagsMap.clear() /** * Interrupt all operations of this session that are currently running. @@ -857,9 +873,15 @@ class SparkSession private( * @since 4.0.0 */ def interruptAll(): Seq[String] = { - val jobIds = mutable.Set[Int]() - sparkContext.cancelAllJobs(jobIdCallback = (jobId: Int) => jobIds.add(jobId)) - jobIds.toSeq.map(_.toString) + val cancelledIds = mutable.Set[Int]() + sparkContext.cancelJobsWithTag( + sessionJobTag, + "Interrupted by user", + shouldCancelJob = (id: Int) => { + cancelledIds += id + true + }) + cancelledIds.toSeq.map(_.toString) } /** @@ -871,12 +893,18 @@ class SparkSession private( * @since 4.0.0 */ def interruptTag(tag: String): Seq[String] = { - val jobIds = mutable.Set[Int]() + val realTag = userDefinedToRealTagsMap.get(tag) + if (realTag == null) return Seq.empty + + val cancelledIds = mutable.Set[Int]() sparkContext.cancelJobsWithTag( - tag, + realTag, "Interrupted by user", - jobIdCallback = (jobId: Int) => jobIds.add(jobId)) - jobIds.toSeq.map(_.toString) + shouldCancelJob = (id: Int) => { + cancelledIds += id + true + }) + cancelledIds.toSeq.map(_.toString) } /** @@ -888,10 +916,22 @@ class SparkSession private( * @since 4.0.0 */ def interruptOperation(jobId: String): Seq[String] = { + val cancelledIds = mutable.Set[Int]() scala.util.Try(jobId.toInt).toOption match { - case Some(id) => - sparkContext.cancelJob(id, "Interrupted by user") - Seq(jobId) + case Some(jobIdToBeCancelled) => + sparkContext.cancelJobsWithTag( + sessionJobTag, + "Interrupted by user", + shouldCancelJob = (givenId: Int) => { + // Test all jobs owned by this session, and kill only the one with the given ID. + if (givenId == jobIdToBeCancelled) { + cancelledIds += jobIdToBeCancelled + true + } else { + false + } + }) + cancelledIds.toSeq.map(_.toString) case None => throw new IllegalArgumentException("jobId must be a number.") } @@ -1015,7 +1055,7 @@ class SparkSession private( } /** - * Execute a block of code with the this session set as the active session, and restore the + * Execute a block of code with this session set as the active session, and restore the * previous session on completion. 
*/ private[sql] def withActive[T](block: => T): T = { @@ -1274,6 +1314,9 @@ object SparkSession extends Logging { */ def setActiveSession(session: SparkSession): Unit = { activeThreadSession.set(session) + if (session != null) { + session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) + } } /** @@ -1283,7 +1326,14 @@ object SparkSession extends Logging { * @since 2.0.0 */ def clearActiveSession(): Unit = { - activeThreadSession.remove() + getActiveSession match { + case Some(session) => + if (session != null) { + session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) + } + activeThreadSession.remove() + case None => // do nothing + } } /** From 0656e25d51918c9608d86468843d4aac57d6996e Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 21 Aug 2024 16:34:04 +0200 Subject: [PATCH 04/28] . --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 8 ++++++++ .../main/scala/org/apache/spark/sql/SparkSession.scala | 3 +-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dd445dd303c0..e508c770331a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2689,7 +2689,7 @@ class SparkContext(config: SparkConf) extends Logging { * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. * @param reason reason for cancellation - * @param shouldCancelJob callback function to be called with the job ID of each job that matches + * @param shouldCancelJob Callback function to be called with the job ID of each job that matches * the given tag. If the function returns true, the job will be cancelled. * * @since 4.0.0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 17c244d760d4..68d82139cc22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1116,12 +1116,20 @@ private[spark] class DAGScheduler( /** * Cancel all jobs with a given tag. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation */ def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = cancelJobsWithTag(tag, reason, shouldCancelJob = None) /** * Cancel all jobs with a given tag. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation + * @param shouldCancelJob Callback function to be called with the job ID of each job that matches + * the given tag. If the function returns true, the job will be cancelled. */ def cancelJobsWithTag( tag: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 4908b825ddcd..2ac9f6c69dfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -128,8 +128,7 @@ class SparkSession private( /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. 
- * Real tag have the current session ID attached: - * "tag1" -> s"spark-514b7fac-51bf-458e-8bd8-0340d646326c-tag1" + * Real tag have the current session ID attached: `"tag1" -> s"spark-$sessionUUID-tag1"`. */ private val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = new ConcurrentHashMap() From 6b6ca7fe46291ca6305f48f45960c0965e895bbb Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 21 Aug 2024 16:59:04 +0200 Subject: [PATCH 05/28] . --- .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2ac9f6c69dfc..7463c86b14ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1312,8 +1312,10 @@ object SparkSession extends Logging { * @since 2.0.0 */ def setActiveSession(session: SparkSession): Unit = { + clearActiveSession() activeThreadSession.set(session) if (session != null) { + session.sparkContext.addJobTag(session.sessionJobTag) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } } @@ -1328,6 +1330,7 @@ object SparkSession extends Logging { getActiveSession match { case Some(session) => if (session != null) { + session.sparkContext.removeJobTag(session.sessionJobTag) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } activeThreadSession.remove() From d3cd5f5f6231a6e266bc978b24bb19c7a4c5124c Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Fri, 23 Aug 2024 15:30:17 +0200 Subject: [PATCH 06/28] new approach --- .../scala/org/apache/spark/SparkContext.scala | 67 +- .../apache/spark/scheduler/ActiveJob.scala | 5 +- .../apache/spark/scheduler/DAGScheduler.scala | 90 ++- .../spark/scheduler/DAGSchedulerEvent.scala | 13 +- .../apache/spark/scheduler/JobWaiter.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 54 +- .../SparkSessionJobCancellationSuite.scala | 745 ++++++++++++++++++ 8 files changed, 887 insertions(+), 91 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e508c770331a..c653052b5e4b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -27,6 +27,7 @@ import scala.collection.Map import scala.collection.concurrent.{Map => ScalaConcurrentMap} import scala.collection.immutable import scala.collection.mutable.HashMap +import scala.concurrent.{Future, Promise} import scala.jdk.CollectionConverters._ import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal @@ -825,6 +826,11 @@ class SparkContext(config: SparkConf) extends Logging { def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).orNull + /** Set the UUID of the Spark session that starts the current job. */ + def setSparkSessionUUID(uuid: String): Unit = { + setLocalProperty(SparkContext.SPARK_SESSION_UUID, uuid) + } + /** Set a human readable description of the current job. 
*/ def setJobDescription(value: String): Unit = { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) @@ -2691,16 +2697,20 @@ class SparkContext(config: SparkConf) extends Logging { * @param reason reason for cancellation * @param shouldCancelJob Callback function to be called with the job ID of each job that matches * the given tag. If the function returns true, the job will be cancelled. + * @return A future that will be completed with the set of job IDs that were cancelled. * * @since 4.0.0 */ def cancelJobsWithTag( tag: String, reason: String, - shouldCancelJob: Int => Boolean): Unit = { + shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(shouldCancelJob)) + + val cancelledJobs = Promise[Set[Int]]() + dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(shouldCancelJob), Some(cancelledJobs)) + cancelledJobs.future } /** @@ -2714,7 +2724,11 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String, reason: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason)) + dagScheduler.cancelJobsWithTag( + tag, + Option(reason), + shouldCancelJob = None, + cancelledJobs = None) } /** @@ -2727,13 +2741,51 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, None) + dagScheduler.cancelJobsWithTag( + tag, + reason = None, + shouldCancelJob = None, + cancelledJobs = None) + } + + /** + * Cancel all jobs that have been scheduled or are running. + * + * @param shouldCancelJob Callback function to be called with the job ID of each job that matches + * the given tag. If the function returns true, the job will be cancelled. + * @return A future that will be completed with the set of job IDs that were cancelled. + */ + def cancelAllJobs(shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { + assertNotStopped() + + val cancelledJobs = Promise[Set[Int]]() + dagScheduler.cancelAllJobs(Some(shouldCancelJob), Some(cancelledJobs)) + cancelledJobs.future } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs(): Unit = { assertNotStopped() - dagScheduler.cancelAllJobs() + dagScheduler.cancelAllJobs(shouldCancelJob = None, cancelledJobs = None) + } + + /** + * Cancel a given job if it's scheduled or running. + * + * @param jobId the job ID to cancel + * @param reason reason for cancellation + * @param shouldCancelJob Callback function to be called with the job ID of each job that matches + * the given tag. If the function returns true, the job will be cancelled. + * @return A future that will be completed with the set of job IDs that were cancelled. 
+ * @note Throws `InterruptedException` if the cancel message cannot be sent + */ + def cancelJob( + jobId: Int, + reason: String, + shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { + val cancelledJobs = Promise[Set[Int]]() + dagScheduler.cancelJob(jobId, Option(reason), Some(shouldCancelJob), Some(cancelledJobs)) + cancelledJobs.future } /** @@ -2744,7 +2796,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int, reason: String): Unit = { - dagScheduler.cancelJob(jobId, Option(reason)) + dagScheduler.cancelJob(jobId, Option(reason), shouldCancelJob = None, cancelledJobs = None) } /** @@ -2754,7 +2806,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int): Unit = { - dagScheduler.cancelJob(jobId, None) + dagScheduler.cancelJob(jobId, reason = None, shouldCancelJob = None, cancelledJobs = None) } /** @@ -3103,6 +3155,7 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" private[spark] val SPARK_JOB_TAGS = "spark.job.tags" private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool" + private[spark] val SPARK_SESSION_UUID = "spark.sparkSession.uuid" private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 9876668194a8..a191320bf054 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.JobArtifactSet +import org.apache.spark.{JobArtifactSet, SparkContext} import org.apache.spark.util.CallSite /** @@ -63,4 +63,7 @@ private[spark] class ActiveJob( val finished = Array.fill[Boolean](numPartitions)(false) var numFinished = 0 + + def getSparkSessionUUID: Option[String] = + Option(properties.getProperty(SparkContext.SPARK_SESSION_UUID)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 68d82139cc22..128bae2874ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -27,6 +27,7 @@ import scala.annotation.tailrec import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -144,7 +145,7 @@ private[spark] class DAGScheduler( private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) - private[scheduler] def numTotalJobs: Int = nextJobId.get() + private[spark] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] @@ -167,7 +168,7 @@ private[spark] class DAGScheduler( // Stages that must be resubmitted due to fetch failures private[scheduler] val failedStages = new HashSet[Stage] - private[scheduler] val activeJobs = new HashSet[ActiveJob] + 
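Taken together, the Future-returning SparkContext variants and the session UUID property above could be driven from Spark-internal code roughly as follows (a sketch loosely mirroring what the SparkSession side of this patch does; `sc` and the UUID value are placeholders, and since ActiveJob and ThreadUtils are private[spark], this only compiles inside Spark itself):

    import scala.concurrent.duration.DurationInt
    import org.apache.spark.util.ThreadUtils

    // Mark jobs submitted from this thread as belonging to one session.
    val sessionUUID = java.util.UUID.randomUUID().toString // placeholder identifier
    sc.setSparkSessionUUID(sessionUUID)

    // Later: cancel only this session's jobs and block (bounded) on the IDs that were cancelled.
    val cancelled = sc.cancelAllJobs(job => job.getSparkSessionUUID.contains(sessionUUID))
    val cancelledJobIds: Set[Int] = ThreadUtils.awaitResult(cancelled, 60.seconds)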
private[spark] val activeJobs = new HashSet[ActiveJob] // Job groups that are cancelled with `cancelFutureJobs` as true, with at most // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in @@ -1099,9 +1100,13 @@ private[spark] class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int, reason: Option[String]): Unit = { + def cancelJob( + jobId: Int, + reason: Option[String], + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}") - eventProcessLoop.post(JobCancelled(jobId, reason)) + eventProcessLoop.post(JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs)) } /** @@ -1114,15 +1119,6 @@ private[spark] class DAGScheduler( eventProcessLoop.post(JobGroupCancelled(groupId, cancelFutureJobs, reason)) } - /** - * Cancel all jobs with a given tag. - * - * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. - * @param reason reason for cancellation - */ - def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = - cancelJobsWithTag(tag, reason, shouldCancelJob = None) - /** * Cancel all jobs with a given tag. * @@ -1134,27 +1130,29 @@ private[spark] class DAGScheduler( def cancelJobsWithTag( tag: String, reason: Option[String], - shouldCancelJob: Option[Int => Boolean]): Unit = { + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob)) + eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs)) } /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs(): Unit = { - eventProcessLoop.post(AllJobsCancelled) + def cancelAllJobs( + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + eventProcessLoop.post(AllJobsCancelled(shouldCancelJob, cancelledJobs)) } - private[scheduler] def doCancelAllJobs(): Unit = { + def doCancelAllJobs( + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { // Cancel all running jobs. - runningStages.map(_.firstJobId).foreach(handleJobCancellation( - _, - Option("as part of cancellation of all jobs"), - shouldCancelJob = None)) - activeJobs.clear() // These should already be empty by this point, - jobIdToActiveJob.clear() // but just in case we lost track of some jobs... + val cancelled = runningStages.map(_.firstJobId) + .filter(doJobCancellation(_, Option("as part of cancellation of all jobs"), shouldCancelJob)) + cancelledJobs.foreach(_.success(cancelled.toSet)) } /** @@ -1250,13 +1248,14 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason), shouldCancelJob = None)) + jobIds.foreach(doJobCancellation(_, Option(updatedReason), shouldCancelJob = None)) } private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], - shouldCancelJob: Option[Int => Boolean]): Unit = { + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { // Cancel all jobs belonging that have this tag. 
// First finds all active jobs with this group id, and then kill stages for them. val jobIds = activeJobs.filter { activeJob => @@ -1266,7 +1265,8 @@ private[spark] class DAGScheduler( } }.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason), shouldCancelJob)) + val cancelled = jobIds.filter(doJobCancellation(_, Option(updatedReason), shouldCancelJob)) + cancelledJobs.foreach(_.success(cancelled.toSet)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2823,7 +2823,7 @@ private[spark] class DAGScheduler( case None => s"because Stage $stageId was cancelled" } - handleJobCancellation(jobId, Option(reasonStr), shouldCancelJob = None) + doJobCancellation(jobId, Option(reasonStr), shouldCancelJob = None) } case None => logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") @@ -2833,15 +2833,29 @@ private[spark] class DAGScheduler( private[scheduler] def handleJobCancellation( jobId: Int, reason: Option[String], - shouldCancelJob: Option[Int => Boolean]): Unit = { + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + val cancelled = Set(jobId).filter(doJobCancellation(_, reason, shouldCancelJob)) + cancelledJobs.foreach(_.success(cancelled)) + } + + private def doJobCancellation( + jobId: Int, + reason: Option[String], + shouldCancelJob: Option[ActiveJob => Boolean]): Boolean = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) + false } else { - if (shouldCancelJob.forall(_(jobId))) { + val activeJob = jobIdToActiveJob(jobId) + if (shouldCancelJob.forall(_(activeJob))) { failJobAndIndependentStages( - job = jobIdToActiveJob(jobId), + job = activeJob, error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) ) + true + } else { + false } } } @@ -3134,17 +3148,17 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case StageCancelled(stageId, reason) => dagScheduler.handleStageCancellation(stageId, reason) - case JobCancelled(jobId, reason) => - dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob = None) + case JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs) => + dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob, cancelledJobs) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason, callback) => - dagScheduler.handleJobTagCancelled(tag, reason, callback) + case JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs) => + dagScheduler.handleJobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs) - case AllJobsCancelled => - dagScheduler.doCancelAllJobs() + case AllJobsCancelled(shouldCancelJob, cancelledJobs) => + dagScheduler.doCancelAllJobs(shouldCancelJob, cancelledJobs) case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) @@ -3200,7 +3214,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler override def onError(e: Throwable): Unit = { logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) try { - dagScheduler.doCancelAllJobs() + dagScheduler.doCancelAllJobs(shouldCancelJob = None, cancelledJobs = None) } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 449729704ae4..d52e82c1d294 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties +import scala.concurrent.Promise + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.{AccumulatorV2, CallSite} @@ -60,7 +62,9 @@ private[scheduler] case class StageCancelled( private[scheduler] case class JobCancelled( jobId: Int, - reason: Option[String]) + reason: Option[String], + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled( @@ -72,9 +76,12 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, reason: Option[String], - shouldCancelJob: Option[Int => Boolean]) extends DAGSchedulerEvent + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent -private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent +private[scheduler] case class AllJobsCancelled( + shouldCancelJob: Option[ActiveJob => Boolean], + cancelledJobs: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index bfd675938703..5b243f82610b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -50,7 +50,7 @@ private[spark] class JobWaiter[T]( * all the tasks belonging to this job, it will fail this job with a SparkException. */ def cancel(reason: Option[String]): Unit = { - dagScheduler.cancelJob(jobId, reason) + dagScheduler.cancelJob(jobId, reason, shouldCancelJob = None, cancelledJobs = None) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 978ceb16b376..fbd26aa52d0a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -534,7 +534,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti /** Sends JobCancelled to the DAG scheduler. */ private def cancel(jobId: Int): Unit = { - runEvent(JobCancelled(jobId, None)) + runEvent(JobCancelled(jobId, reason = None, shouldCancelJob = None, cancelledJobs = None)) } /** Make some tasks in task set success and check results. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 7463c86b14ce..6e963245d365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -23,7 +23,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} -import scala.collection.mutable +import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -35,7 +35,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME} import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{ActiveJob, SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ @@ -58,7 +58,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ /** @@ -124,8 +124,6 @@ class SparkSession private( .getOrElse(SQLConf.getFallbackConf) }) - private lazy val sessionJobTag = s"spark-session-$sessionUUID" - /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. * Real tag have the current session ID attached: `"tag1" -> s"spark-$sessionUUID-tag1"`. @@ -831,7 +829,7 @@ class SparkSession private( */ def addTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - userDefinedToRealTagsMap.put(tag, s"spark-$sessionUUID-$tag") + userDefinedToRealTagsMap.put(tag, s"spark-session-$sessionUUID-$tag") } /** @@ -868,19 +866,11 @@ class SparkSession private( * * @return * sequence of Job IDs of interrupted operations. 
- * * @since 4.0.0 */ def interruptAll(): Seq[String] = { - val cancelledIds = mutable.Set[Int]() - sparkContext.cancelJobsWithTag( - sessionJobTag, - "Interrupted by user", - shouldCancelJob = (id: Int) => { - cancelledIds += id - true - }) - cancelledIds.toSeq.map(_.toString) + val cancelledIds = sparkContext.cancelAllJobs(_.getSparkSessionUUID.contains(sessionUUID)) + ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq } /** @@ -895,15 +885,8 @@ class SparkSession private( val realTag = userDefinedToRealTagsMap.get(tag) if (realTag == null) return Seq.empty - val cancelledIds = mutable.Set[Int]() - sparkContext.cancelJobsWithTag( - realTag, - "Interrupted by user", - shouldCancelJob = (id: Int) => { - cancelledIds += id - true - }) - cancelledIds.toSeq.map(_.toString) + val cancelledIds = sparkContext.cancelJobsWithTag(realTag, "Interrupted by user", _ => true) + ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq } /** @@ -915,22 +898,13 @@ class SparkSession private( * @since 4.0.0 */ def interruptOperation(jobId: String): Seq[String] = { - val cancelledIds = mutable.Set[Int]() scala.util.Try(jobId.toInt).toOption match { case Some(jobIdToBeCancelled) => - sparkContext.cancelJobsWithTag( - sessionJobTag, + val cancelledIds = sparkContext.cancelJob( + jobIdToBeCancelled, "Interrupted by user", - shouldCancelJob = (givenId: Int) => { - // Test all jobs owned by this session, and kill only the one with the given ID. - if (givenId == jobIdToBeCancelled) { - cancelledIds += jobIdToBeCancelled - true - } else { - false - } - }) - cancelledIds.toSeq.map(_.toString) + shouldCancelJob = _.getSparkSessionUUID.contains(sessionUUID)) + ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq case None => throw new IllegalArgumentException("jobId must be a number.") } @@ -1315,7 +1289,7 @@ object SparkSession extends Logging { clearActiveSession() activeThreadSession.set(session) if (session != null) { - session.sparkContext.addJobTag(session.sessionJobTag) + session.sparkContext.setSparkSessionUUID(session.sessionUUID) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } } @@ -1330,7 +1304,7 @@ object SparkSession extends Logging { getActiveSession match { case Some(session) => if (session != null) { - session.sparkContext.removeJobTag(session.sessionJobTag) + session.sparkContext.setSparkSessionUUID(null) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } activeThreadSession.remove() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala new file mode 100644 index 000000000000..475bd24ec48c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala @@ -0,0 +1,745 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future} +import scala.jdk.CollectionConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.logging.log4j.Level +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT +import org.apache.spark.internal.config.UI.UI_ENABLED +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.util.ExecutionListenerBus +import org.apache.spark.tags.ExtendedSQLTest +import org.apache.spark.util.ThreadUtils + +/** + * Test cases for the cancellation APIs provided by [[SparkSession]]. + */ +@ExtendedSQLTest +class SparkSessionJobCancellationSuite + extends SparkFunSuite + with Eventually + with LocalSparkContext { + + override def afterEach(): Unit = { + try { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + resetSparkContext() + } finally { + super.afterEach() + } + } + + test("Cancellation APIs in SparkSession are isolated") { + sc = new SparkContext("local[2]", "test") + val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate() + var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) = + (null, null, null) + + // global ExecutionContext has only 2 threads in Apache Spark CI + // create own thread pool for four Futures used in this test + val numThreads = 3 + val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + try { + // Add a listener to release the semaphore once jobs are launched. + val sem = new Semaphore(0) + val jobEnded = new AtomicInteger(0) + + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sem.release() + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + sem.release() + jobEnded.incrementAndGet() + } + }) + + // Note: since tags are added in the Future threads, they don't need to be cleared in between. 
+ val jobA = Future { + sessionA = globalSession.cloneSession() + import globalSession.implicits._ + + assert(sessionA.getTags() == Set()) + sessionA.addTag("two") + assert(sessionA.getTags() == Set("two")) + sessionA.clearTags() // check that clearing all tags works + assert(sessionA.getTags() == Set()) + sessionA.addTag("one") + assert(sessionA.getTags() == Set("one")) + try { + sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count() + } finally { + sessionA.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobB = Future { + sessionB = globalSession.cloneSession() + import globalSession.implicits._ + + assert(sessionB.getTags() == Set()) + sessionB.addTag("one") + sessionB.addTag("two") + sessionB.addTag("one") + sessionB.addTag("two") // duplicates shouldn't matter + assert(sessionB.getTags() == Set("one", "two")) + try { + sessionB.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sessionB.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobC = Future { + sessionC = globalSession.cloneSession() + import globalSession.implicits._ + + sessionC.addTag("foo") + sessionC.removeTag("foo") + assert(sessionC.getTags() == Set()) // check that remove works removing the last tag + sessionC.addTag("boo") + try { + sessionC.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sessionC.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + + // Block until four jobs have started. + assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES)) + + // Tags are applied + val threeJobs = sc.dagScheduler.activeJobs + assert(threeJobs.size == 3) + for(ss <- Seq(sessionA, sessionB, sessionC)) { + val job = threeJobs.filter(_.getSparkSessionUUID.getOrElse("") == ss.sessionUUID) + assert(job.size == 1) + val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] + .split(SparkContext.SPARK_JOB_TAGS_SEP) + assert(tags.forall(_.contains(s"spark-session-${ss.sessionUUID}-"))) + val userTags = tags.map(_.replace(s"spark-session-${ss.sessionUUID}-", "")) + ss match { + case s if s == sessionA => assert(userTags.toSet == Set("one")) + case s if s == sessionB => assert(userTags.toSet == Set("one", "two")) + case s if s == sessionC => assert(userTags.toSet == Set("boo")) + } + } + + // Global session cancels nothing + assert(globalSession.interruptAll().isEmpty) + assert(globalSession.interruptTag("one").isEmpty) + assert(globalSession.interruptTag("two").isEmpty) + for (i <- 0 until globalSession.sparkContext.dagScheduler.numTotalJobs) { + assert(globalSession.interruptOperation(i.toString).isEmpty) + } + assert(jobEnded.intValue == 0) + + // One job cancelled + for (i <- 0 until globalSession.sparkContext.dagScheduler.numTotalJobs) { + sessionC.interruptOperation(i.toString) + } + val eC = intercept[SparkException] { + ThreadUtils.awaitResult(jobC, 1.minute) + }.getCause + assert(eC.getMessage contains "Interrupted") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 1) + + // Another job cancelled + assert(sessionA.interruptTag("one").size == 1) + val eA = intercept[SparkException] { + ThreadUtils.awaitResult(jobA, 1.minute) + }.getCause + assert(eA.getMessage contains "Interrupted") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 2) + + // The last job cancelled + sessionB.interruptAll() + val eB = intercept[SparkException] { + 
ThreadUtils.awaitResult(jobB, 1.minute) + }.getCause + assert(eB.getMessage contains "cancellation of all jobs") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 3) + } finally { + fpool.shutdownNow() + } + } + + + test("SPARK-34087: Fix memory leak of ExecutionListenerBus") { + val spark = SparkSession.builder() + .master("local") + .getOrCreate() + + @inline def listenersNum(): Int = { + spark.sparkContext + .listenerBus + .listeners + .asScala + .count(_.isInstanceOf[ExecutionListenerBus]) + } + + (1 to 10).foreach { _ => + spark.cloneSession() + SparkSession.clearActiveSession() + } + + eventually(timeout(10.seconds), interval(1.seconds)) { + System.gc() + // After GC, the number of ExecutionListenerBus should be less than 11 (we created 11 + // SparkSessions in total). + // Since GC can't 100% guarantee all out-of-referenced objects be cleaned at one time, + // here, we check at least one listener is cleaned up to prove the mechanism works. + assert(listenersNum() < 11) + } + } + + test("create with config options and propagate them to SparkContext and SparkSession") { + val session = SparkSession.builder() + .master("local") + .config(UI_ENABLED.key, value = false) + .config("some-config", "v2") + .getOrCreate() + assert(session.sparkContext.conf.get("some-config") == "v2") + assert(session.conf.get("some-config") == "v2") + } + + test("use global default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.builder().getOrCreate() == session) + } + + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[SparkException](SparkSession.active) + session.stop() + } + + test("config options are propagated to existing SparkSession") { + val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() + assert(session1.conf.get("spark-config1") == "a") + val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() + assert(session1 == session2) + assert(session1.conf.get("spark-config1") == "b") + } + + test("use session from active thread session and propagate config options") { + val defaultSession = SparkSession.builder().master("local").getOrCreate() + val activeSession = defaultSession.newSession() + SparkSession.setActiveSession(activeSession) + val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() + + assert(activeSession != defaultSession) + assert(session == activeSession) + assert(session.conf.get("spark-config2") == "a") + assert(session.sessionState.conf == SQLConf.get) + assert(SQLConf.get.getConfString("spark-config2") == "a") + SparkSession.clearActiveSession() + + assert(SparkSession.builder().getOrCreate() == defaultSession) + } + + test("create a new session if the default session has been stopped") { + val defaultSession = SparkSession.builder().master("local").getOrCreate() + SparkSession.setDefaultSession(defaultSession) + defaultSession.stop() + val newSession = 
SparkSession.builder().master("local").getOrCreate() + assert(newSession != defaultSession) + } + + test("create a new session if the active thread session has been stopped") { + val activeSession = SparkSession.builder().master("local").getOrCreate() + SparkSession.setActiveSession(activeSession) + activeSession.stop() + val newSession = SparkSession.builder().master("local").getOrCreate() + assert(newSession != activeSession) + } + + test("create SparkContext first then SparkSession") { + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val sparkContext2 = new SparkContext(conf) + val session = SparkSession.builder().config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == sparkContext2) + // We won't update conf for existing `SparkContext` + assert(!sparkContext2.conf.contains("key2")) + assert(sparkContext2.conf.get("key1") == "value1") + } + + test("create SparkContext first then pass context to SparkSession") { + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val newSC = new SparkContext(conf) + val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == newSC) + assert(session.sparkContext.conf.get("key1") == "value1") + // If the created sparkContext is passed through the Builder's API sparkContext, + // the conf of this sparkContext will not contain the conf set through the API config. + assert(!session.sparkContext.conf.contains("key2")) + assert(session.sparkContext.conf.get("spark.app.name") == "test") + } + + test("SPARK-15887: hive-site.xml should be loaded") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") + assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") + } + + test("SPARK-15991: Set global Hadoop conf") { + val session = SparkSession.builder().master("local").getOrCreate() + val mySpecialKey = "my.special.key.15991" + val mySpecialValue = "msv" + try { + session.sparkContext.hadoopConfiguration.set(mySpecialKey, mySpecialValue) + assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) + } finally { + session.sparkContext.hadoopConfiguration.unset(mySpecialKey) + } + } + + test("SPARK-31234: RESET command will not change static sql configs and " + + "spark context conf values in SessionState") { + val session = SparkSession.builder() + .master("local") + .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31234") + .config("spark.app.name", "test-app-SPARK-31234") + .getOrCreate() + + assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") + assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") + session.sql("RESET") + assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") + assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") + } + + test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test-app-SPARK-31354-1") + val context = new SparkContext(conf) + SparkSession + .builder() + .sparkContext(context) + .master("local") + .getOrCreate() + val postFirstCreation = 
context.listenerBus.listeners.size() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + + SparkSession + .builder() + .sparkContext(context) + .master("local") + .getOrCreate() + val postSecondCreation = context.listenerBus.listeners.size() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + assert(postFirstCreation == postSecondCreation) + } + + test("SPARK-31532: should not propagate static sql configs to the existing" + + " active/default SparkSession") { + val session = SparkSession.builder() + .master("local") + .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31532") + .config("spark.app.name", "test-app-SPARK-31532") + .getOrCreate() + // do not propagate static sql configs to the existing active session + val session1 = SparkSession + .builder() + .config(GLOBAL_TEMP_DATABASE.key, "globalTempDB-SPARK-31532-1") + .getOrCreate() + assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") + assert(session1.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") + + // do not propagate static sql configs to the existing default session + SparkSession.clearActiveSession() + val session2 = SparkSession + .builder() + .config(WAREHOUSE_PATH.key, "SPARK-31532-db") + .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31532-2") + .getOrCreate() + + assert(!session.conf.get(WAREHOUSE_PATH).contains("SPARK-31532-db")) + assert(session.conf.get(WAREHOUSE_PATH) === session2.conf.get(WAREHOUSE_PATH)) + assert(session2.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") + } + + test("SPARK-31532: propagate static sql configs if no existing SparkSession") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test-app-SPARK-31532-2") + .set(GLOBAL_TEMP_DATABASE.key, "globaltempdb-spark-31532") + .set(WAREHOUSE_PATH.key, "SPARK-31532-db") + SparkContext.getOrCreate(conf) + + // propagate static sql configs if no existing session + val session = SparkSession + .builder() + .config(GLOBAL_TEMP_DATABASE.key, "globalTempDB-SPARK-31532-2") + .config(WAREHOUSE_PATH.key, "SPARK-31532-db-2") + .getOrCreate() + assert(session.conf.get("spark.app.name") === "test-app-SPARK-31532-2") + assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532-2") + assert(session.conf.get(WAREHOUSE_PATH) contains "SPARK-31532-db-2") + } + + test("SPARK-32062: reset listenerRegistered in SparkSession") { + (1 to 2).foreach { i => + val conf = new SparkConf() + .setMaster("local") + .setAppName(s"test-SPARK-32062-$i") + val context = new SparkContext(conf) + val beforeListenerSize = context.listenerBus.listeners.size() + SparkSession + .builder() + .sparkContext(context) + .getOrCreate() + val afterListenerSize = context.listenerBus.listeners.size() + assert(beforeListenerSize + 1 == afterListenerSize) + context.stop() + } + } + + test("SPARK-32160: Disallow to create SparkSession in executors") { + val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate() + + val error = intercept[SparkException] { + session.range(1).foreach { v => + SparkSession.builder().master("local").getOrCreate() + () + } + }.getMessage() + + assert(error.contains("SparkSession should only be created and accessed on the driver.")) + } + + test("SPARK-32160: Allow to create SparkSession in executors if the config is set") { + val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate() + + session.range(1).foreach { v => + SparkSession.builder().master("local") + 
.config(EXECUTOR_ALLOW_SPARK_CONTEXT.key, true).getOrCreate().stop() + () + } + } + + test("SPARK-32991: Use conf in shared state as the original configuration for RESET") { + val wh = "spark.sql.warehouse.dir" + val td = "spark.sql.globalTempDatabase" + val custom = "spark.sql.custom" + + val conf = new SparkConf() + .setMaster("local") + .setAppName("SPARK-32991") + .set(wh, "./data1") + .set(td, "bob") + + val sc = new SparkContext(conf) + + val spark = SparkSession.builder() + .config(wh, "./data2") + .config(td, "alice") + .config(custom, "kyao") + .getOrCreate() + + // When creating the first session like above, we will update the shared spark conf to the + // newly specified values + val sharedWH = spark.sharedState.conf.get(wh) + val sharedTD = spark.sharedState.conf.get(td) + assert(sharedWH contains "data2", + "The warehouse dir in shared state should be determined by the 1st created spark session") + assert(sharedTD === "alice", + "Static sql configs in shared state should be determined by the 1st created spark session") + assert(spark.sharedState.conf.getOption(custom).isEmpty, + "Dynamic sql configs is session specific") + + assert(spark.conf.get(wh) contains sharedWH, + "The warehouse dir in session conf and shared state conf should be consistent") + assert(spark.conf.get(td) === sharedTD, + "Static sql configs in session conf and shared state conf should be consistent") + assert(spark.conf.get(custom) === "kyao", "Dynamic sql configs is session specific") + + spark.sql("RESET") + + assert(spark.conf.get(wh) contains sharedWH, + "The warehouse dir in shared state should be respect after RESET") + assert(spark.conf.get(td) === sharedTD, + "Static sql configs in shared state should be respect after RESET") + assert(spark.conf.get(custom) === "kyao", + "Dynamic sql configs in session initial map should be respect after RESET") + + val spark2 = SparkSession.builder() + .config(wh, "./data3") + .config(custom, "kyaoo").getOrCreate() + assert(spark2.conf.get(wh) contains sharedWH) + assert(spark2.conf.get(td) === sharedTD) + assert(spark2.conf.get(custom) === "kyaoo") + } + + test("SPARK-32991: RESET should work properly with multi threads") { + val wh = "spark.sql.warehouse.dir" + val td = "spark.sql.globalTempDatabase" + val custom = "spark.sql.custom" + val spark = ThreadUtils.runInNewThread("new session 0", false) { + SparkSession.builder() + .master("local") + .config(wh, "./data0") + .config(td, "bob") + .config(custom, "c0") + .getOrCreate() + } + + spark.sql(s"SET $custom=c1") + assert(spark.conf.get(custom) === "c1") + spark.sql("RESET") + assert(spark.conf.get(wh) contains "data0", + "The warehouse dir in shared state should be respect after RESET") + assert(spark.conf.get(td) === "bob", + "Static sql configs in shared state should be respect after RESET") + assert(spark.conf.get(custom) === "c0", + "Dynamic sql configs in shared state should be respect after RESET") + + val spark1 = ThreadUtils.runInNewThread("new session 1", false) { + SparkSession.builder().getOrCreate() + } + + assert(spark === spark1) + + // TODO: SPARK-33718: After clear sessions, the SharedState will be unreachable, then all + // the new static will take effect. 
+ SparkSession.clearDefaultSession() + val spark2 = ThreadUtils.runInNewThread("new session 2", false) { + SparkSession.builder() + .master("local") + .config(wh, "./data1") + .config(td, "alice") + .config(custom, "c2") + .getOrCreate() + } + + assert(spark2 !== spark) + spark2.sql(s"SET $custom=c1") + assert(spark2.conf.get(custom) === "c1") + spark2.sql("RESET") + assert(spark2.conf.get(wh) contains "data1") + assert(spark2.conf.get(td) === "alice") + assert(spark2.conf.get(custom) === "c2") + + } + + test("SPARK-33944: warning setting hive.metastore.warehouse.dir using session options") { + val msg = "Not allowing to set hive.metastore.warehouse.dir in SparkSession's options" + val logAppender = new LogAppender(msg) + withLogAppender(logAppender) { + SparkSession.builder() + .master("local") + .config("hive.metastore.warehouse.dir", "any") + .getOrCreate() + .sharedState + } + assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) + } + + test("SPARK-33944: no warning setting spark.sql.warehouse.dir using session options") { + val msg = "Not allowing to set hive.metastore.warehouse.dir in SparkSession's options" + val logAppender = new LogAppender(msg) + withLogAppender(logAppender) { + SparkSession.builder() + .master("local") + .config("spark.sql.warehouse.dir", "any") + .getOrCreate() + .sharedState + } + assert(!logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) + } + + Seq(".", "..", "dir0", "dir0/dir1", "/dir0/dir1", "./dir0").foreach { pathStr => + test(s"SPARK-34558: warehouse path ($pathStr) should be qualified for spark/hadoop conf") { + val path = new Path(pathStr) + val conf = new SparkConf().set(WAREHOUSE_PATH, pathStr) + val session = SparkSession.builder() + .master("local") + .config(conf) + .getOrCreate() + val hadoopConf = session.sessionState.newHadoopConf() + val expected = path.getFileSystem(hadoopConf).makeQualified(path).toString + // session related configs + assert(hadoopConf.get("hive.metastore.warehouse.dir") === expected) + assert(session.conf.get(WAREHOUSE_PATH) === expected) + assert(session.sessionState.conf.warehousePath === expected) + + // shared configs + assert(session.sharedState.conf.get(WAREHOUSE_PATH) === expected) + assert(session.sharedState.hadoopConf.get("hive.metastore.warehouse.dir") === expected) + + // spark context configs + assert(session.sparkContext.conf.get(WAREHOUSE_PATH) === expected) + assert(session.sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") === + expected) + } + } + + test("SPARK-34558: Create a working SparkSession with a broken FileSystem") { + val msg = "Cannot qualify the warehouse path, leaving it unqualified" + val logAppender = new LogAppender(msg) + withLogAppender(logAppender) { + val session = + SparkSession.builder() + .master("local") + .config(WAREHOUSE_PATH.key, "unknown:///mydir") + .getOrCreate() + session.sql("SELECT 1").collect() + } + assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) + } + + test("SPARK-37727: Show ignored configurations in debug level logs") { + // Create one existing SparkSession to check following logs. 
+ SparkSession.builder().master("local").getOrCreate() + + val logAppender = new LogAppender + logAppender.setThreshold(Level.DEBUG) + withLogAppender(logAppender, level = Some(Level.DEBUG)) { + SparkSession.builder() + .config("spark.sql.warehouse.dir", "2") + .config("spark.abc", "abcb") + .config("spark.abcd", "abcb4") + .getOrCreate() + } + + val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) + Seq( + "Ignored static SQL configurations", + "spark.sql.warehouse.dir=2", + "Configurations that might not take effect", + "spark.abcd=abcb4", + "spark.abc=abcb").foreach { msg => + assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") + } + } + + test("SPARK-37727: Hide the same configuration already explicitly set in logs") { + // Create one existing SparkSession to check following logs. + SparkSession.builder().master("local").config("spark.abc", "abc").getOrCreate() + + val logAppender = new LogAppender + logAppender.setThreshold(Level.DEBUG) + withLogAppender(logAppender, level = Some(Level.DEBUG)) { + // Ignore logs because it's already set. + SparkSession.builder().config("spark.abc", "abc").getOrCreate() + // Show logs for only configuration newly set. + SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() + // Ignore logs because it's set ^. + SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() + } + + val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) + Seq( + "Using an existing Spark session; only runtime SQL configurations will take effect", + "Configurations that might not take effect", + "spark.abc.new=abc").foreach { msg => + assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") + } + + assert( + !logs.exists(_.contains("spark.abc=abc")), + s"'spark.abc=abc' existed in:\n${logs.mkString("\n")}") + } + + test("SPARK-37727: Hide runtime SQL configurations in logs") { + // Create one existing SparkSession to check following logs. 
+ SparkSession.builder().master("local").getOrCreate() + + val logAppender = new LogAppender + logAppender.setThreshold(Level.DEBUG) + withLogAppender(logAppender, level = Some(Level.DEBUG)) { + // Ignore logs for runtime SQL configurations + SparkSession.builder().config("spark.sql.ansi.enabled", "true").getOrCreate() + // Show logs for Spark core configuration + SparkSession.builder().config("spark.buffer.size", "1234").getOrCreate() + // Show logs for custom runtime options + SparkSession.builder().config("spark.sql.source.abc", "abc").getOrCreate() + // Show logs for static SQL configurations + SparkSession.builder().config("spark.sql.warehouse.dir", "xyz").getOrCreate() + } + + val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) + Seq( + "spark.buffer.size=1234", + "spark.sql.source.abc=abc", + "spark.sql.warehouse.dir=xyz").foreach { msg => + assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") + } + + assert( + !logs.exists(_.contains("spark.sql.ansi.enabled\"")), + s"'spark.sql.ansi.enabled' existed in:\n${logs.mkString("\n")}") + } + + test("SPARK-40163: SparkSession.config(Map)") { + val map: Map[String, Any] = Map( + "string" -> "", + "boolean" -> true, + "double" -> 0.0, + "long" -> 0L + ) + + val session = SparkSession.builder() + .master("local") + .config(map) + .getOrCreate() + + for (e <- map) { + assert(session.conf.get(e._1) == e._2.toString) + } + } +} From 0922dd287c4f5de18b21cc3393887f7dd8282a85 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Mon, 26 Aug 2024 16:58:01 +0200 Subject: [PATCH 07/28] address comments --- .../scala/org/apache/spark/SparkContext.scala | 8 ++--- .../org/apache/spark/sql/SparkSession.scala | 30 +++++++++++++----- ...ssionJobTaggingAndCancellationSuite.scala} | 31 +++++++++++++++++-- 3 files changed, 55 insertions(+), 14 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{SparkSessionJobCancellationSuite.scala => SparkSessionJobTaggingAndCancellationSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c653052b5e4b..d04a2b55ce63 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2698,10 +2698,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param shouldCancelJob Callback function to be called with the job ID of each job that matches * the given tag. If the function returns true, the job will be cancelled. * @return A future that will be completed with the set of job IDs that were cancelled. - * - * @since 4.0.0 */ - def cancelJobsWithTag( + private[spark] def cancelJobsWithTag( tag: String, reason: String, shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { @@ -2755,7 +2753,7 @@ class SparkContext(config: SparkConf) extends Logging { * the given tag. If the function returns true, the job will be cancelled. * @return A future that will be completed with the set of job IDs that were cancelled. */ - def cancelAllJobs(shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { + private[spark] def cancelAllJobs(shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { assertNotStopped() val cancelledJobs = Promise[Set[Int]]() @@ -2779,7 +2777,7 @@ class SparkContext(config: SparkConf) extends Logging { * @return A future that will be completed with the set of job IDs that were cancelled. 
* @note Throws `InterruptedException` if the cancel message cannot be sent */ - def cancelJob( + private[spark] def cancelJob( jobId: Int, reason: String, shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6e963245d365..e31d3222bfef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -93,7 +93,8 @@ class SparkSession private( @transient private val existingSharedState: Option[SharedState], @transient private val parentSessionState: Option[SessionState], @transient private[sql] val extensions: SparkSessionExtensions, - @transient private[sql] val initialSessionOptions: Map[String, String]) + @transient private[sql] val initialSessionOptions: Map[String, String], + @transient private val parentUserDefinedToRealTagsMap: Map[String, String]) extends Serializable with Closeable with Logging { self => // The call site where this SparkSession was constructed. @@ -108,8 +109,13 @@ class SparkSession private( private[sql] def this( sc: SparkContext, initialSessionOptions: java.util.HashMap[String, String]) = { - this(sc, None, None, SparkSession.applyExtensions(sc, new SparkSessionExtensions), - initialSessionOptions.asScala.toMap) + this( + sc, + existingSharedState = None, + parentSessionState = None, + SparkSession.applyExtensions(sc, new SparkSessionExtensions), + initialSessionOptions.asScala.toMap, + parentUserDefinedToRealTagsMap = Map.empty) } private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) @@ -128,7 +134,9 @@ class SparkSession private( * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. * Real tag have the current session ID attached: `"tag1" -> s"spark-$sessionUUID-tag1"`. */ - private val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = new ConcurrentHashMap() + @transient + private lazy val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = + new ConcurrentHashMap(parentUserDefinedToRealTagsMap.asJava) /** * The version of Spark on which this application is running. 
@@ -285,7 +293,8 @@ class SparkSession private( Some(sharedState), parentSessionState = None, extensions, - initialSessionOptions) + initialSessionOptions, + parentUserDefinedToRealTagsMap = Map.empty) } /** @@ -306,8 +315,10 @@ class SparkSession private( Some(sharedState), Some(sessionState), extensions, - Map.empty) + Map.empty, + userDefinedToRealTagsMap.asScala.toMap) result.sessionState // force copy of SessionState + result.userDefinedToRealTagsMap // force copy of userDefinedToRealTagsMap result } @@ -1261,7 +1272,12 @@ object SparkSession extends Logging { loadExtensions(extensions) applyExtensions(sparkContext, extensions) - session = new SparkSession(sparkContext, None, None, extensions, options.toMap) + session = new SparkSession(sparkContext, + existingSharedState = None, + parentSessionState = None, + extensions, + initialSessionOptions = options.toMap, + parentUserDefinedToRealTagsMap = Map.empty) setDefaultSession(session) setActiveSession(session) registerContextListener(sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 475bd24ec48c..0b0081bcaac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -39,10 +39,10 @@ import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.util.ThreadUtils /** - * Test cases for the cancellation APIs provided by [[SparkSession]]. + * Test cases for the tagging and cancellation APIs provided by [[SparkSession]]. 
*/ @ExtendedSQLTest -class SparkSessionJobCancellationSuite +class SparkSessionJobTaggingAndCancellationSuite extends SparkFunSuite with Eventually with LocalSparkContext { @@ -60,6 +60,33 @@ class SparkSessionJobCancellationSuite } } + test("Tags are not inherited by new sessions") { + val session = SparkSession.builder().master("local").getOrCreate() + + assert(session.getTags() == Set()) + session.addTag("one") + assert(session.getTags() == Set("one")) + + val newSession = session.newSession() + assert(newSession.getTags() == Set()) + } + + test("Tags are inherited by cloned sessions") { + val session = SparkSession.builder().master("local").getOrCreate() + + assert(session.getTags() == Set()) + session.addTag("one") + assert(session.getTags() == Set("one")) + + val clonedSession = session.cloneSession() + assert(clonedSession.getTags() == Set("one")) + clonedSession.addTag("two") + assert(clonedSession.getTags() == Set("one", "two")) + + // Tags are not propagated back to the original session + assert(session.getTags() == Set("one")) + } + test("Cancellation APIs in SparkSession are isolated") { sc = new SparkContext("local[2]", "test") val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate() From 2a6fcc694a748ef70cd5a918d7e98ade10dc98ee Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Mon, 26 Aug 2024 17:58:30 +0200 Subject: [PATCH 08/28] return job IDs earlier --- .../scala/org/apache/spark/SparkContext.scala | 10 +-- .../apache/spark/scheduler/DAGScheduler.scala | 79 ++++++++++--------- .../spark/scheduler/DAGSchedulerEvent.scala | 6 +- .../apache/spark/scheduler/JobWaiter.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- 5 files changed, 51 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d04a2b55ce63..5b22ab5dfac5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2726,7 +2726,7 @@ class SparkContext(config: SparkConf) extends Logging { tag, Option(reason), shouldCancelJob = None, - cancelledJobs = None) + jobsToBeCancelled = None) } /** @@ -2743,7 +2743,7 @@ class SparkContext(config: SparkConf) extends Logging { tag, reason = None, shouldCancelJob = None, - cancelledJobs = None) + jobsToBeCancelled = None) } /** @@ -2764,7 +2764,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Cancel all jobs that have been scheduled or are running. 
*/ def cancelAllJobs(): Unit = { assertNotStopped() - dagScheduler.cancelAllJobs(shouldCancelJob = None, cancelledJobs = None) + dagScheduler.cancelAllJobs(shouldCancelJob = None, jobsToBeCancelled = None) } /** @@ -2794,7 +2794,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int, reason: String): Unit = { - dagScheduler.cancelJob(jobId, Option(reason), shouldCancelJob = None, cancelledJobs = None) + dagScheduler.cancelJob(jobId, Option(reason), shouldCancelJob = None, jobsToBeCancelled = None) } /** @@ -2804,7 +2804,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int): Unit = { - dagScheduler.cancelJob(jobId, reason = None, shouldCancelJob = None, cancelledJobs = None) + dagScheduler.cancelJob(jobId, reason = None, shouldCancelJob = None, jobsToBeCancelled = None) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 128bae2874ad..8cbcef23398d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1104,9 +1104,9 @@ private[spark] class DAGScheduler( jobId: Int, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}") - eventProcessLoop.post(JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs)) + eventProcessLoop.post(JobCancelled(jobId, reason, shouldCancelJob, jobsToBeCancelled)) } /** @@ -1131,10 +1131,10 @@ private[spark] class DAGScheduler( tag: String, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs)) + eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled)) } /** @@ -1142,17 +1142,17 @@ private[spark] class DAGScheduler( */ def cancelAllJobs( shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { - eventProcessLoop.post(AllJobsCancelled(shouldCancelJob, cancelledJobs)) + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { + eventProcessLoop.post(AllJobsCancelled(shouldCancelJob, jobsToBeCancelled)) } def doCancelAllJobs( shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { // Cancel all running jobs. 
- val cancelled = runningStages.map(_.firstJobId) - .filter(doJobCancellation(_, Option("as part of cancellation of all jobs"), shouldCancelJob)) - cancelledJobs.foreach(_.success(cancelled.toSet)) + val jobIds = runningStages.map(_.firstJobId).filter(shouldCancelJobId(_, shouldCancelJob)) + jobsToBeCancelled.foreach(_.success(jobIds.toSet)) + jobIds.foreach(doJobCancellation(_, Option("as part of cancellation of all jobs"))) } /** @@ -1248,14 +1248,14 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId)) - jobIds.foreach(doJobCancellation(_, Option(updatedReason), shouldCancelJob = None)) + jobIds.foreach(doJobCancellation(_, Option(updatedReason))) } private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { // Cancel all jobs belonging that have this tag. // First finds all active jobs with this group id, and then kill stages for them. val jobIds = activeJobs.filter { activeJob => @@ -1265,8 +1265,9 @@ private[spark] class DAGScheduler( } }.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - val cancelled = jobIds.filter(doJobCancellation(_, Option(updatedReason), shouldCancelJob)) - cancelledJobs.foreach(_.success(cancelled.toSet)) + val idsToBeCancelled = jobIds.filter(shouldCancelJobId(_, shouldCancelJob)) + jobsToBeCancelled.foreach(_.success(idsToBeCancelled.toSet)) + idsToBeCancelled.foreach(doJobCancellation(_, Option(updatedReason))) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2823,7 +2824,7 @@ private[spark] class DAGScheduler( case None => s"because Stage $stageId was cancelled" } - doJobCancellation(jobId, Option(reasonStr), shouldCancelJob = None) + doJobCancellation(jobId, Option(reasonStr)) } case None => logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") @@ -2834,29 +2835,31 @@ private[spark] class DAGScheduler( jobId: Int, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]): Unit = { - val cancelled = Set(jobId).filter(doJobCancellation(_, reason, shouldCancelJob)) - cancelledJobs.foreach(_.success(cancelled)) + jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { + val shouldCancel = shouldCancelJobId(jobId, shouldCancelJob) + if (shouldCancel) { + jobsToBeCancelled.foreach(_.success(Set(jobId))) + doJobCancellation(jobId, reason) + } else { + jobsToBeCancelled.foreach(_.success(Set.empty)) + } } - - private def doJobCancellation( + + private def shouldCancelJobId( jobId: Int, - reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean]): Boolean = { + shouldCancelJob.forall(_(jobIdToActiveJob(jobId))) + } + + private def doJobCancellation(jobId: Int, reason: Option[String]): Unit = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) - false } else { val activeJob = jobIdToActiveJob(jobId) - if (shouldCancelJob.forall(_(activeJob))) { - failJobAndIndependentStages( - job = activeJob, - error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) - ) - true - } else { - false - } + failJobAndIndependentStages( + job = activeJob, + error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) + ) } } @@ -3148,17 +3151,17 
@@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case StageCancelled(stageId, reason) => dagScheduler.handleStageCancellation(stageId, reason) - case JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs) => - dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob, cancelledJobs) + case JobCancelled(jobId, reason, shouldCancelJob, jobsToBeCancelled) => + dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob, jobsToBeCancelled) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs) => - dagScheduler.handleJobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs) + case JobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled) => + dagScheduler.handleJobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled) - case AllJobsCancelled(shouldCancelJob, cancelledJobs) => - dagScheduler.doCancelAllJobs(shouldCancelJob, cancelledJobs) + case AllJobsCancelled(shouldCancelJob, jobsToBeCancelled) => + dagScheduler.doCancelAllJobs(shouldCancelJob, jobsToBeCancelled) case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) @@ -3214,7 +3217,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler override def onError(e: Throwable): Unit = { logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) try { - dagScheduler.doCancelAllJobs(shouldCancelJob = None, cancelledJobs = None) + dagScheduler.doCancelAllJobs(shouldCancelJob = None, jobsToBeCancelled = None) } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index d52e82c1d294..947dc633ac5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -64,7 +64,7 @@ private[scheduler] case class JobCancelled( jobId: Int, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]) + jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled( @@ -77,11 +77,11 @@ private[scheduler] case class JobTagCancelled( tagName: String, reason: Option[String], shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent + jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent private[scheduler] case class AllJobsCancelled( shouldCancelJob: Option[ActiveJob => Boolean], - cancelledJobs: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent + jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 5b243f82610b..7df0a74317c7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -50,7 +50,7 @@ private[spark] class JobWaiter[T]( * all the tasks belonging to this job, it will fail this job with a SparkException. 
*/ def cancel(reason: Option[String]): Unit = { - dagScheduler.cancelJob(jobId, reason, shouldCancelJob = None, cancelledJobs = None) + dagScheduler.cancelJob(jobId, reason, shouldCancelJob = None, jobsToBeCancelled = None) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index fbd26aa52d0a..447004cd7e64 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -534,7 +534,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti /** Sends JobCancelled to the DAG scheduler. */ private def cancel(jobId: Int): Unit = { - runEvent(JobCancelled(jobId, reason = None, shouldCancelJob = None, cancelledJobs = None)) + runEvent(JobCancelled(jobId, reason = None, shouldCancelJob = None, jobsToBeCancelled = None)) } /** Make some tasks in task set success and check results. */ From ef0fddf802f69f4fc369b407b27186db62e3c1ff Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Mon, 26 Aug 2024 18:18:33 +0200 Subject: [PATCH 09/28] doc --- .../org/apache/spark/sql/SparkSession.scala | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e31d3222bfef..f0fb8677aa7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -873,10 +873,12 @@ class SparkSession private( def clearTags(): Unit = userDefinedToRealTagsMap.clear() /** - * Interrupt all operations of this session that are currently running. + * Request to interrupt all currently running operations of this session. * - * @return - * sequence of Job IDs of interrupted operations. + * @note This method will wait up to 60 seconds for the interruption request to be issued. + + * @return Sequence of job IDs requested to be interrupted. + * @since 4.0.0 */ def interruptAll(): Seq[String] = { @@ -885,12 +887,12 @@ class SparkSession private( } /** - * Interrupt all operations of this session with the given operation tag. + * Request to interrupt all currently running operations of this session with the given operation + * tag. * - * @return - * sequence of Job IDs of interrupted operations. + * @note This method will wait up to 60 seconds for the interruption request to be issued. * - * @since 4.0.0 + * @return Sequence of job IDs requested to be interrupted. */ def interruptTag(tag: String): Seq[String] = { val realTag = userDefinedToRealTagsMap.get(tag) @@ -901,10 +903,12 @@ class SparkSession private( } /** - * Interrupt an operation of this session with the given Job ID. + * Request to interrupt an operation of this session, given its job ID. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. * - * @return - * sequence of Job IDs of interrupted operations. + * @return The job ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. 
* * @since 4.0.0 */ @@ -917,7 +921,7 @@ class SparkSession private( shouldCancelJob = _.getSparkSessionUUID.contains(sessionUUID)) ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq case None => - throw new IllegalArgumentException("jobId must be a number.") + throw new IllegalArgumentException("jobId must be a number in string form.") } } From f2ad16316c75571059a927221db215996d02a9b4 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 27 Aug 2024 11:06:28 +0200 Subject: [PATCH 10/28] no mention of spark session in core --- .../main/scala/org/apache/spark/SparkContext.scala | 6 ------ .../org/apache/spark/scheduler/ActiveJob.scala | 5 +---- .../scala/org/apache/spark/sql/SparkSession.scala | 13 +++++++++---- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5b22ab5dfac5..d5a20e11509d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -826,11 +826,6 @@ class SparkContext(config: SparkConf) extends Logging { def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).orNull - /** Set the UUID of the Spark session that starts the current job. */ - def setSparkSessionUUID(uuid: String): Unit = { - setLocalProperty(SparkContext.SPARK_SESSION_UUID, uuid) - } - /** Set a human readable description of the current job. */ def setJobDescription(value: String): Unit = { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) @@ -3153,7 +3148,6 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" private[spark] val SPARK_JOB_TAGS = "spark.job.tags" private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool" - private[spark] val SPARK_SESSION_UUID = "spark.sparkSession.uuid" private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index a191320bf054..9876668194a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.{JobArtifactSet, SparkContext} +import org.apache.spark.JobArtifactSet import org.apache.spark.util.CallSite /** @@ -63,7 +63,4 @@ private[spark] class ActiveJob( val finished = Array.fill[Boolean](numPartitions)(false) var numFinished = 0 - - def getSparkSessionUUID: Option[String] = - Option(properties.getProperty(SparkContext.SPARK_SESSION_UUID)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index f0fb8677aa7e..30ccffe89f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -58,6 +58,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.SparkSession.SPARK_SESSION_UUID_PROPERTY_KEY import org.apache.spark.util.{CallSite, ThreadUtils, Utils} import 
org.apache.spark.util.ArrayImplicits._ @@ -882,7 +883,8 @@ class SparkSession private( * @since 4.0.0 */ def interruptAll(): Seq[String] = { - val cancelledIds = sparkContext.cancelAllJobs(_.getSparkSessionUUID.contains(sessionUUID)) + val cancelledIds = sparkContext.cancelAllJobs( + _.properties.getProperty(SPARK_SESSION_UUID_PROPERTY_KEY) == sessionUUID) ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq } @@ -918,7 +920,8 @@ class SparkSession private( val cancelledIds = sparkContext.cancelJob( jobIdToBeCancelled, "Interrupted by user", - shouldCancelJob = _.getSparkSessionUUID.contains(sessionUUID)) + shouldCancelJob = _.properties.getProperty(SPARK_SESSION_UUID_PROPERTY_KEY) == sessionUUID + ) ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq case None => throw new IllegalArgumentException("jobId must be a number in string form.") @@ -1084,6 +1087,8 @@ class SparkSession private( @Stable object SparkSession extends Logging { + private val SPARK_SESSION_UUID_PROPERTY_KEY = "spark.sparkSession.uuid" + /** * Builder for [[SparkSession]]. */ @@ -1309,7 +1314,7 @@ object SparkSession extends Logging { clearActiveSession() activeThreadSession.set(session) if (session != null) { - session.sparkContext.setSparkSessionUUID(session.sessionUUID) + session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, session.sessionUUID) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } } @@ -1324,7 +1329,7 @@ object SparkSession extends Logging { getActiveSession match { case Some(session) => if (session != null) { - session.sparkContext.setSparkSessionUUID(null) + session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, null) session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) } activeThreadSession.remove() From ab0068532ef8bd7c2393fc403c4626e5bee6dd60 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 27 Aug 2024 11:06:56 +0200 Subject: [PATCH 11/28] re --- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 30ccffe89f0f..dd7ecdca1a45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME} import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SparkSession.SPARK_SESSION_UUID_PROPERTY_KEY import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ @@ -58,7 +59,6 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.sql.SparkSession.SPARK_SESSION_UUID_PROPERTY_KEY import org.apache.spark.util.{CallSite, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ From dd10f4625b6805159653900eef4f075c954470ae Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 27 Aug 2024 15:06:22 +0200 Subject: [PATCH 12/28] fix test --- 
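For context before the next patches: the changes above and below keep reworking how the scheduler reports back which job IDs a cancellation request will affect. The recurring shape is a Promise[Set[Int]] that the scheduler's event loop completes once it has picked the matching jobs, while the caller blocks on the corresponding Future with a timeout (the SparkSession side uses ThreadUtils.awaitResult(..., 60.seconds)). Below is a minimal standalone sketch of that hand-off in plain Scala, not Spark code: the object name, the Map-based model of active jobs and their tags, and the single-thread executor standing in for the scheduler's event loop are all assumptions made only for illustration.

{{{
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration.DurationInt

object CancellationHandOffSketch {
  // Stand-in for the scheduler's single-threaded event loop (an assumption for this sketch).
  private val eventLoop =
    ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

  /** Post a "cancel jobs with this tag" request; the future holds the matched job IDs. */
  def cancelJobsWithTag(activeJobs: Map[Int, Set[String]], tag: String): Future[Set[Int]] = {
    val jobsToBeCancelled = Promise[Set[Int]]()
    eventLoop.execute { () =>
      // Decide on the event loop which jobs match, and complete the promise first ...
      val matched = activeJobs.collect { case (id, tags) if tags(tag) => id }.toSet
      jobsToBeCancelled.success(matched)
      // ... the (potentially slow) cancellation of each matched job would run after this.
    }
    jobsToBeCancelled.future
  }

  def main(args: Array[String]): Unit = {
    val jobs = Map(1 -> Set("etl"), 2 -> Set("adhoc"), 3 -> Set("etl"))
    // Caller side: block with a timeout, mirroring ThreadUtils.awaitResult(..., 60.seconds).
    val ids = Await.result(cancelJobsWithTag(jobs, "etl"), 60.seconds)
    println(ids) // Set(1, 3)
    eventLoop.shutdown()
  }
}
}}}

Completing the promise before the actual cancellation work runs is what lets the session-side interrupt methods return the affected job IDs early, instead of waiting for every matched job to finish failing.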
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- .../spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index dd7ecdca1a45..3d993b78601c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -133,7 +133,7 @@ class SparkSession private( /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. - * Real tag have the current session ID attached: `"tag1" -> s"spark-$sessionUUID-tag1"`. + * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. */ @transient private lazy val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 0b0081bcaac4..3ad0d73b9d97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -171,7 +171,7 @@ class SparkSessionJobTaggingAndCancellationSuite val threeJobs = sc.dagScheduler.activeJobs assert(threeJobs.size == 3) for(ss <- Seq(sessionA, sessionB, sessionC)) { - val job = threeJobs.filter(_.getSparkSessionUUID.getOrElse("") == ss.sessionUUID) + val job = threeJobs.filter(_.properties.get("spark.sparkSession.uuid") == ss.sessionUUID) assert(job.size == 1) val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] .split(SparkContext.SPARK_JOB_TAGS_SEP) From bc9b76d71f18ae939b56306ce4454eed98793ae7 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 28 Aug 2024 12:23:30 +0200 Subject: [PATCH 13/28] revert some changes --- .../scala/org/apache/spark/SparkContext.scala | 93 ++++++++----------- .../apache/spark/scheduler/DAGScheduler.scala | 83 ++++++----------- .../spark/scheduler/DAGSchedulerEvent.scala | 9 +- .../apache/spark/scheduler/JobWaiter.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 49 +++++----- 6 files changed, 88 insertions(+), 150 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d5a20e11509d..8136bed0038d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -917,6 +917,13 @@ class SparkContext(config: SparkConf) extends Logging { setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } + /** + * Add a tag to be assigned to all the jobs started by this thread. The tag will be prefixed with + * an internal prefix to avoid conflicts with user tags. + */ + private[spark] def addInternalJobTag(tag: String): Unit = + addJobTag(s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag") + /** * Remove a tag previously added to be assigned to all the jobs started by this thread. * Noop if such a tag was not added earlier. @@ -936,6 +943,16 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** + * Get the tags that are currently set to be assigned to all the jobs started by this thread. 
+ */ + private[spark] def getInternalJobTags(): Set[String] = { + Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS)) + .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet) + .getOrElse(Set()) + .filter(_.startsWith(SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX)) // only internal tags + } + /** * Get the tags that are currently set to be assigned to all the jobs started by this thread. * @@ -945,7 +962,8 @@ class SparkContext(config: SparkConf) extends Logging { Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS)) .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet) .getOrElse(Set()) - .filter(!_.isEmpty) // empty string tag should not happen, but be defensive + .filterNot(_.startsWith(SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX)) // exclude internal tags + .filter(_.nonEmpty) // empty string tag should not happen, but be defensive } /** @@ -954,7 +972,14 @@ class SparkContext(config: SparkConf) extends Logging { * @since 3.5.0 */ def clearJobTags(): Unit = { - setLocalProperty(SparkContext.SPARK_JOB_TAGS, null) + val internalTags = getInternalJobTags() + if (internalTags.isEmpty) { + setLocalProperty(SparkContext.SPARK_JOB_TAGS, null) + } else { + setLocalProperty( + SparkContext.SPARK_JOB_TAGS, + internalTags.mkString(SparkContext.SPARK_JOB_TAGS_SEP)) + } } /** @@ -2690,19 +2715,14 @@ class SparkContext(config: SparkConf) extends Logging { * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. * @param reason reason for cancellation - * @param shouldCancelJob Callback function to be called with the job ID of each job that matches - * the given tag. If the function returns true, the job will be cancelled. * @return A future that will be completed with the set of job IDs that were cancelled. */ - private[spark] def cancelJobsWithTag( - tag: String, - reason: String, - shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { + private[spark] def cancelJobsWithTagWithFuture(tag: String, reason: String): Future[Set[Int]] = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() val cancelledJobs = Promise[Set[Int]]() - dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(shouldCancelJob), Some(cancelledJobs)) + dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(cancelledJobs)) cancelledJobs.future } @@ -2717,11 +2737,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String, reason: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag( - tag, - Option(reason), - shouldCancelJob = None, - jobsToBeCancelled = None) + dagScheduler.cancelJobsWithTag(tag, Option(reason), jobsToBeCancelled = None) } /** @@ -2734,51 +2750,13 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag( - tag, - reason = None, - shouldCancelJob = None, - jobsToBeCancelled = None) - } - - /** - * Cancel all jobs that have been scheduled or are running. - * - * @param shouldCancelJob Callback function to be called with the job ID of each job that matches - * the given tag. If the function returns true, the job will be cancelled. - * @return A future that will be completed with the set of job IDs that were cancelled. 
- */ - private[spark] def cancelAllJobs(shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { - assertNotStopped() - - val cancelledJobs = Promise[Set[Int]]() - dagScheduler.cancelAllJobs(Some(shouldCancelJob), Some(cancelledJobs)) - cancelledJobs.future + dagScheduler.cancelJobsWithTag(tag, reason = None, jobsToBeCancelled = None) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs(): Unit = { assertNotStopped() - dagScheduler.cancelAllJobs(shouldCancelJob = None, jobsToBeCancelled = None) - } - - /** - * Cancel a given job if it's scheduled or running. - * - * @param jobId the job ID to cancel - * @param reason reason for cancellation - * @param shouldCancelJob Callback function to be called with the job ID of each job that matches - * the given tag. If the function returns true, the job will be cancelled. - * @return A future that will be completed with the set of job IDs that were cancelled. - * @note Throws `InterruptedException` if the cancel message cannot be sent - */ - private[spark] def cancelJob( - jobId: Int, - reason: String, - shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = { - val cancelledJobs = Promise[Set[Int]]() - dagScheduler.cancelJob(jobId, Option(reason), Some(shouldCancelJob), Some(cancelledJobs)) - cancelledJobs.future + dagScheduler.cancelAllJobs() } /** @@ -2789,7 +2767,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int, reason: String): Unit = { - dagScheduler.cancelJob(jobId, Option(reason), shouldCancelJob = None, jobsToBeCancelled = None) + dagScheduler.cancelJob(jobId, Option(reason)) } /** @@ -2799,7 +2777,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int): Unit = { - dagScheduler.cancelJob(jobId, reason = None, shouldCancelJob = None, jobsToBeCancelled = None) + dagScheduler.cancelJob(jobId, reason = None) } /** @@ -3161,6 +3139,9 @@ object SparkContext extends Logging { /** Separator of tags in SPARK_JOB_TAGS property */ private[spark] val SPARK_JOB_TAGS_SEP = "," + /** Prefix to mark a tag to be visible internally, not by users */ + private[spark] val SPARK_JOB_TAGS_INTERNAL_PREFIX = "__internal_tag__" + // Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag private[spark] def throwIfInvalidTag(tag: String) = { if (tag == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8cbcef23398d..8aa5e41b1535 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1100,13 +1100,9 @@ private[spark] class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob( - jobId: Int, - reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { + def cancelJob(jobId: Int, reason: Option[String]): Unit = { logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}") - eventProcessLoop.post(JobCancelled(jobId, reason, shouldCancelJob, jobsToBeCancelled)) + eventProcessLoop.post(JobCancelled(jobId, reason)) } /** @@ -1123,36 +1119,31 @@ private[spark] class DAGScheduler( * Cancel all jobs with a given tag. * * @param tag The tag to be cancelled. 
Cannot contain ',' (comma) character. - * @param reason reason for cancellation - * @param shouldCancelJob Callback function to be called with the job ID of each job that matches - * the given tag. If the function returns true, the job will be cancelled. + * @param reason reason for cancellation. + * @param jobsToBeCancelled a promise to be completed with the set of job ids that are cancelled. */ def cancelJobsWithTag( tag: String, reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled)) + eventProcessLoop.post(JobTagCancelled(tag, reason, jobsToBeCancelled)) } /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs( - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { - eventProcessLoop.post(AllJobsCancelled(shouldCancelJob, jobsToBeCancelled)) + def cancelAllJobs(): Unit = { + eventProcessLoop.post(AllJobsCancelled) } - def doCancelAllJobs( - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { + private[scheduler] def doCancelAllJobs(): Unit = { // Cancel all running jobs. - val jobIds = runningStages.map(_.firstJobId).filter(shouldCancelJobId(_, shouldCancelJob)) - jobsToBeCancelled.foreach(_.success(jobIds.toSet)) - jobIds.foreach(doJobCancellation(_, Option("as part of cancellation of all jobs"))) + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, + Option("as part of cancellation of all jobs"))) + activeJobs.clear() // These should already be empty by this point, + jobIdToActiveJob.clear() // but just in case we lost track of some jobs... } /** @@ -1248,13 +1239,12 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId)) - jobIds.foreach(doJobCancellation(_, Option(updatedReason))) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) } private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { // Cancel all jobs belonging that have this tag. // First finds all active jobs with this group id, and then kill stages for them. 
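A minimal sketch, assuming job tags live in a java.util.Properties entry keyed by SparkContext.SPARK_JOB_TAGS and separated by ',' (SPARK_JOB_TAGS_SEP), of the membership test the next hunk applies to each active job's properties; the property key literal below is a placeholder, not the real constant:

    import java.util.Properties

    // True if the job's properties carry `tag` in their comma-separated tag list.
    def jobHasTag(props: Properties, tagsKey: String, tag: String): Boolean =
      Option(props.getProperty(tagsKey))
        .getOrElse("")
        .split(",")                 // same separator as SparkContext.SPARK_JOB_TAGS_SEP
        .filter(_.nonEmpty)
        .toSet
        .contains(tag)

    val props = new Properties()
    props.setProperty("spark.job.tags", "a,b")   // key literal is a placeholder
    assert(jobHasTag(props, "spark.job.tags", "a") && !jobHasTag(props, "spark.job.tags", "c"))
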
@@ -1264,10 +1254,10 @@ private[spark] class DAGScheduler( .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } }.map(_.jobId) + jobsToBeCancelled.foreach(_.success(jobIds.toSet)) + val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - val idsToBeCancelled = jobIds.filter(shouldCancelJobId(_, shouldCancelJob)) - jobsToBeCancelled.foreach(_.success(idsToBeCancelled.toSet)) - idsToBeCancelled.foreach(doJobCancellation(_, Option(updatedReason))) + jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2824,40 +2814,19 @@ private[spark] class DAGScheduler( case None => s"because Stage $stageId was cancelled" } - doJobCancellation(jobId, Option(reasonStr)) + handleJobCancellation(jobId, Option(reasonStr)) } case None => logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") } } - private[scheduler] def handleJobCancellation( - jobId: Int, - reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { - val shouldCancel = shouldCancelJobId(jobId, shouldCancelJob) - if (shouldCancel) { - jobsToBeCancelled.foreach(_.success(Set(jobId))) - doJobCancellation(jobId, reason) - } else { - jobsToBeCancelled.foreach(_.success(Set.empty)) - } - } - - private def shouldCancelJobId( - jobId: Int, - shouldCancelJob: Option[ActiveJob => Boolean]): Boolean = { - shouldCancelJob.forall(_(jobIdToActiveJob(jobId))) - } - - private def doJobCancellation(jobId: Int, reason: Option[String]): Unit = { + private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - val activeJob = jobIdToActiveJob(jobId) failJobAndIndependentStages( - job = activeJob, + job = jobIdToActiveJob(jobId), error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null) ) } @@ -3151,17 +3120,17 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case StageCancelled(stageId, reason) => dagScheduler.handleStageCancellation(stageId, reason) - case JobCancelled(jobId, reason, shouldCancelJob, jobsToBeCancelled) => - dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob, jobsToBeCancelled) + case JobCancelled(jobId, reason) => + dagScheduler.handleJobCancellation(jobId, reason) case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled) => - dagScheduler.handleJobTagCancelled(tag, reason, shouldCancelJob, jobsToBeCancelled) + case JobTagCancelled(tag, reason, jobsToBeCancelled) => + dagScheduler.handleJobTagCancelled(tag, reason, jobsToBeCancelled) - case AllJobsCancelled(shouldCancelJob, jobsToBeCancelled) => - dagScheduler.doCancelAllJobs(shouldCancelJob, jobsToBeCancelled) + case AllJobsCancelled => + dagScheduler.doCancelAllJobs() case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) @@ -3217,7 +3186,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler override def onError(e: Throwable): Unit = { logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) try { - dagScheduler.doCancelAllJobs(shouldCancelJob = None, jobsToBeCancelled = None) + dagScheduler.doCancelAllJobs() } catch { case t: 
Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 947dc633ac5f..0fd86ee12be4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -62,9 +62,7 @@ private[scheduler] case class StageCancelled( private[scheduler] case class JobCancelled( jobId: Int, - reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]) + reason: Option[String]) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled( @@ -76,12 +74,9 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, reason: Option[String], - shouldCancelJob: Option[ActiveJob => Boolean], jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent -private[scheduler] case class AllJobsCancelled( - shouldCancelJob: Option[ActiveJob => Boolean], - jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent +private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 7df0a74317c7..bfd675938703 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -50,7 +50,7 @@ private[spark] class JobWaiter[T]( * all the tasks belonging to this job, it will fail this job with a SparkException. */ def cancel(reason: Option[String]): Unit = { - dagScheduler.cancelJob(jobId, reason, shouldCancelJob = None, jobsToBeCancelled = None) + dagScheduler.cancelJob(jobId, reason) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 447004cd7e64..5346fa6cdfd1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -534,7 +534,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti /** Sends JobCancelled to the DAG scheduler. */ private def cancel(jobId: Int): Unit = { - runEvent(JobCancelled(jobId, reason = None, shouldCancelJob = None, jobsToBeCancelled = None)) + runEvent(JobCancelled(jobId, reason = None)) } /** Make some tasks in task set success and check results. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3d993b78601c..f25beca1d4d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -36,7 +36,6 @@ import org.apache.spark.internal.LogKeys.{CALL_SITE_LONG_FORM, CLASS_NAME} import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.SparkSession.SPARK_SESSION_UUID_PROPERTY_KEY import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ @@ -95,7 +94,7 @@ class SparkSession private( @transient private val parentSessionState: Option[SessionState], @transient private[sql] val extensions: SparkSessionExtensions, @transient private[sql] val initialSessionOptions: Map[String, String], - @transient private val parentUserDefinedToRealTagsMap: Map[String, String]) + @transient private val parentManagedJobTags: Map[String, String]) extends Serializable with Closeable with Logging { self => // The call site where this SparkSession was constructed. @@ -116,7 +115,7 @@ class SparkSession private( parentSessionState = None, SparkSession.applyExtensions(sc, new SparkSessionExtensions), initialSessionOptions.asScala.toMap, - parentUserDefinedToRealTagsMap = Map.empty) + parentManagedJobTags = Map.empty) } private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) @@ -131,13 +130,16 @@ class SparkSession private( .getOrElse(SQLConf.getFallbackConf) }) + /** Tag to mark all jobs owned by this session. */ + private lazy val sessionJobTag = s"spark-session-$sessionUUID" + /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. */ @transient - private lazy val userDefinedToRealTagsMap: ConcurrentHashMap[String, String] = - new ConcurrentHashMap(parentUserDefinedToRealTagsMap.asJava) + private lazy val managedJobTags: ConcurrentHashMap[String, String] = + new ConcurrentHashMap(parentManagedJobTags.asJava) /** * The version of Spark on which this application is running. 
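A small sketch of the tag-map cloning that the following hunks wire up: a child session seeds its own ConcurrentHashMap from a snapshot of the parent's managed tags (the same construction as the managedJobTags field above), so the two sessions evolve independently afterwards; the tag values here are placeholders:

    import java.util.concurrent.ConcurrentHashMap
    import scala.jdk.CollectionConverters._

    val parentManagedJobTags: Map[String, String] =
      Map("one" -> "spark-session-parent-uuid-one")             // placeholder value
    // Child copies the parent's snapshot into its own concurrent map.
    val childManagedJobTags = new ConcurrentHashMap[String, String](parentManagedJobTags.asJava)

    childManagedJobTags.put("two", "spark-session-child-uuid-two")
    assert(childManagedJobTags.containsKey("one"))              // inherited from parent
    assert(!parentManagedJobTags.contains("two"))               // parent is unaffected
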
@@ -295,7 +297,7 @@ class SparkSession private( parentSessionState = None, extensions, initialSessionOptions, - parentUserDefinedToRealTagsMap = Map.empty) + parentManagedJobTags = Map.empty) } /** @@ -317,9 +319,9 @@ class SparkSession private( Some(sessionState), extensions, Map.empty, - userDefinedToRealTagsMap.asScala.toMap) + managedJobTags.asScala.toMap) result.sessionState // force copy of SessionState - result.userDefinedToRealTagsMap // force copy of userDefinedToRealTagsMap + result.managedJobTags // force copy of userDefinedToRealTagsMap result } @@ -841,7 +843,7 @@ class SparkSession private( */ def addTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - userDefinedToRealTagsMap.put(tag, s"spark-session-$sessionUUID-$tag") + managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag") } /** @@ -855,7 +857,7 @@ class SparkSession private( */ def removeTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - userDefinedToRealTagsMap.remove(tag) + managedJobTags.remove(tag) } /** @@ -864,14 +866,14 @@ class SparkSession private( * * @since 4.0.0 */ - def getTags(): Set[String] = userDefinedToRealTagsMap.keys().asScala.toSet + def getTags(): Set[String] = managedJobTags.keys().asScala.toSet /** * Clear the current thread's operation tags. * * @since 4.0.0 */ - def clearTags(): Unit = userDefinedToRealTagsMap.clear() + def clearTags(): Unit = managedJobTags.clear() /** * Request to interrupt all currently running operations of this session. @@ -882,11 +884,7 @@ class SparkSession private( * @since 4.0.0 */ - def interruptAll(): Seq[String] = { - val cancelledIds = sparkContext.cancelAllJobs( - _.properties.getProperty(SPARK_SESSION_UUID_PROPERTY_KEY) == sessionUUID) - ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq - } + def interruptAll(): Seq[String] = interruptTag(sessionJobTag) /** * Request to interrupt all currently running operations of this session with the given operation @@ -897,10 +895,10 @@ class SparkSession private( * @return Sequence of job IDs requested to be interrupted. 
*/ def interruptTag(tag: String): Seq[String] = { - val realTag = userDefinedToRealTagsMap.get(tag) + val realTag = managedJobTags.get(tag) if (realTag == null) return Seq.empty - val cancelledIds = sparkContext.cancelJobsWithTag(realTag, "Interrupted by user", _ => true) + val cancelledIds = sparkContext.cancelJobsWithTagWithFuture(realTag, "Interrupted by user") ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq } @@ -917,12 +915,7 @@ class SparkSession private( def interruptOperation(jobId: String): Seq[String] = { scala.util.Try(jobId.toInt).toOption match { case Some(jobIdToBeCancelled) => - val cancelledIds = sparkContext.cancelJob( - jobIdToBeCancelled, - "Interrupted by user", - shouldCancelJob = _.properties.getProperty(SPARK_SESSION_UUID_PROPERTY_KEY) == sessionUUID - ) - ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq + Seq() case None => throw new IllegalArgumentException("jobId must be a number in string form.") } @@ -1286,7 +1279,7 @@ object SparkSession extends Logging { parentSessionState = None, extensions, initialSessionOptions = options.toMap, - parentUserDefinedToRealTagsMap = Map.empty) + parentManagedJobTags = Map.empty) setDefaultSession(session) setActiveSession(session) registerContextListener(sparkContext) @@ -1315,7 +1308,7 @@ object SparkSession extends Logging { activeThreadSession.set(session) if (session != null) { session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, session.sessionUUID) - session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) + session.managedJobTags.values().asScala.foreach(session.sparkContext.addJobTag) } } @@ -1330,7 +1323,7 @@ object SparkSession extends Logging { case Some(session) => if (session != null) { session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, null) - session.userDefinedToRealTagsMap.values().asScala.foreach(session.sparkContext.addJobTag) + session.managedJobTags.values().asScala.foreach(session.sparkContext.addJobTag) } activeThreadSession.remove() case None => // do nothing From 865681081b9b7f53a0fe71e0c5f6427bc60e9211 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 28 Aug 2024 12:27:13 +0200 Subject: [PATCH 14/28] undo --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 17 +---------------- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8136bed0038d..579bb6cdf4a0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2777,7 +2777,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int): Unit = { - dagScheduler.cancelJob(jobId, reason = None) + dagScheduler.cancelJob(jobId, None) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5346fa6cdfd1..978ceb16b376 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -534,7 +534,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti /** Sends JobCancelled to the DAG scheduler. 
*/ private def cancel(jobId: Int): Unit = { - runEvent(JobCancelled(jobId, reason = None)) + runEvent(JobCancelled(jobId, None)) } /** Make some tasks in task set success and check results. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index f25beca1d4d6..9fe05bb3655d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1080,8 +1080,6 @@ class SparkSession private( @Stable object SparkSession extends Logging { - private val SPARK_SESSION_UUID_PROPERTY_KEY = "spark.sparkSession.uuid" - /** * Builder for [[SparkSession]]. */ @@ -1304,12 +1302,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def setActiveSession(session: SparkSession): Unit = { - clearActiveSession() activeThreadSession.set(session) - if (session != null) { - session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, session.sessionUUID) - session.managedJobTags.values().asScala.foreach(session.sparkContext.addJobTag) - } } /** @@ -1319,15 +1312,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def clearActiveSession(): Unit = { - getActiveSession match { - case Some(session) => - if (session != null) { - session.sparkContext.setLocalProperty(SPARK_SESSION_UUID_PROPERTY_KEY, null) - session.managedJobTags.values().asScala.foreach(session.sparkContext.addJobTag) - } - activeThreadSession.remove() - case None => // do nothing - } + activeThreadSession.remove() } /** From 1dfafad25a86b952a7fd6c1ef7f63af6e9e2c392 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 28 Aug 2024 16:08:12 +0200 Subject: [PATCH 15/28] wip --- .../scala/org/apache/spark/SparkContext.scala | 12 +- .../org/apache/spark/sql/SparkSession.scala | 4 +- .../spark/sql/execution/SQLExecution.scala | 200 ++++++++++-------- ...essionJobTaggingAndCancellationSuite.scala | 14 +- 4 files changed, 129 insertions(+), 101 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 579bb6cdf4a0..a3b92a954e0e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -912,7 +912,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def addJobTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - val existingTags = getJobTags() + val existingTags = getJobTags() ++ getInternalJobTags() val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } @@ -934,7 +934,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def removeJobTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - val existingTags = getJobTags() + val existingTags = getJobTags() ++ getInternalJobTags() val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) if (newTags.isEmpty) { clearJobTags() @@ -943,6 +943,12 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** + * Remove an internal tag previously added to be assigned to all the jobs started by this thread. + */ + private[spark] def removeInternalJobTag(tag: String): Unit = + removeJobTag(s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag") + /** * Get the tags that are currently set to be assigned to all the jobs started by this thread. 
*/ @@ -3140,7 +3146,7 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_TAGS_SEP = "," /** Prefix to mark a tag to be visible internally, not by users */ - private[spark] val SPARK_JOB_TAGS_INTERNAL_PREFIX = "__internal_tag__" + private[spark] val SPARK_JOB_TAGS_INTERNAL_PREFIX = "~~spark~internal~tag~~" // Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag private[spark] def throwIfInvalidTag(tag: String) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 9fe05bb3655d..0725955f6ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -131,14 +131,14 @@ class SparkSession private( }) /** Tag to mark all jobs owned by this session. */ - private lazy val sessionJobTag = s"spark-session-$sessionUUID" + private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID" /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. */ @transient - private lazy val managedJobTags: ConcurrentHashMap[String, String] = + private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = new ConcurrentHashMap(parentManagedJobTags.asJava) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 58fff2d4a1a2..ab93133648c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -116,92 +116,94 @@ object SQLExecution extends Logging { val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) withSQLConfPropagated(sparkSession) { - var ex: Option[Throwable] = None - var isExecutedPlanAvailable = false - val startTime = System.nanoTime() - val startEvent = SparkListenerSQLExecutionStart( - executionId = executionId, - rootExecutionId = Some(rootExecutionId), - description = desc, - details = callSite.longForm, - physicalPlanDescription = "", - sparkPlanInfo = SparkPlanInfo.EMPTY, - time = System.currentTimeMillis(), - modifiedConfigs = redactedConfigs, - jobTags = sc.getJobTags(), - jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) - try { - body match { - case Left(e) => - sc.listenerBus.post(startEvent) + withSessionTagsApplied(sparkSession) { + var ex: Option[Throwable] = None + var isExecutedPlanAvailable = false + val startTime = System.nanoTime() + val startEvent = SparkListenerSQLExecutionStart( + executionId = executionId, + rootExecutionId = Some(rootExecutionId), + description = desc, + details = callSite.longForm, + physicalPlanDescription = "", + sparkPlanInfo = SparkPlanInfo.EMPTY, + time = System.currentTimeMillis(), + modifiedConfigs = redactedConfigs, + jobTags = sc.getJobTags() ++ sc.getInternalJobTags(), + jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) + try { + body match { + case Left(e) => + sc.listenerBus.post(startEvent) + throw e + case Right(f) => + val planDescriptionMode = + ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val planDesc = queryExecution.explainString(planDescriptionMode) + val planInfo = try { + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) + 
} catch { + case NonFatal(e) => + logDebug("Failed to generate SparkPlanInfo", e) + // If the queryExecution already failed before this, we are not able to generate + // the the plan info, so we use and empty graphviz node to make the UI happy + SparkPlanInfo.EMPTY + } + sc.listenerBus.post( + startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) + isExecutedPlanAvailable = true + f() + } + } catch { + case e: Throwable => + ex = Some(e) throw e - case Right(f) => - val planDescriptionMode = - ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) - val planDesc = queryExecution.explainString(planDescriptionMode) - val planInfo = try { - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) - } catch { - case NonFatal(e) => - logDebug("Failed to generate SparkPlanInfo", e) - // If the queryExecution already failed before this, we are not able to generate - // the the plan info, so we use and empty graphviz node to make the UI happy - SparkPlanInfo.EMPTY - } - sc.listenerBus.post( - startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) - isExecutedPlanAvailable = true - f() - } - } catch { - case e: Throwable => - ex = Some(e) - throw e - } finally { - val endTime = System.nanoTime() - val errorMessage = ex.map { - case e: SparkThrowable => - SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) - case e => - Utils.exceptionString(e) - } - if (queryExecution.shuffleCleanupMode != DoNotCleanup - && isExecutedPlanAvailable) { - val shuffleIds = queryExecution.executedPlan match { - case ae: AdaptiveSparkPlanExec => - ae.context.shuffleIds.asScala.keys - case _ => - Iterable.empty + } finally { + val endTime = System.nanoTime() + val errorMessage = ex.map { + case e: SparkThrowable => + SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) + case e => + Utils.exceptionString(e) } - shuffleIds.foreach { shuffleId => - queryExecution.shuffleCleanupMode match { - case RemoveShuffleFiles => - // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister - // the shuffle on MapOutputTracker, so that stage retries would be triggered. - // Set blocking to Utils.isTesting to deflake unit tests. - sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) - case SkipMigration => - SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) - case _ => // this should not happen + if (queryExecution.shuffleCleanupMode != DoNotCleanup + && isExecutedPlanAvailable) { + val shuffleIds = queryExecution.executedPlan match { + case ae: AdaptiveSparkPlanExec => + ae.context.shuffleIds.asScala.keys + case _ => + Iterable.empty + } + shuffleIds.foreach { shuffleId => + queryExecution.shuffleCleanupMode match { + case RemoveShuffleFiles => + // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister + // the shuffle on MapOutputTracker, so that stage retries would be triggered. + // Set blocking to Utils.isTesting to deflake unit tests. + sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) + case SkipMigration => + SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) + case _ => // this should not happen + } } } + val event = SparkListenerSQLExecutionEnd( + executionId, + System.currentTimeMillis(), + // Use empty string to indicate no error, as None may mean events generated by old + // versions of Spark. 
+ errorMessage.orElse(Some(""))) + // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the + // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with + // name. We can specify the execution name in more places in the future, so that + // `QueryExecutionListener` can track more cases. + event.executionName = name + event.duration = endTime - startTime + event.qe = queryExecution + event.executionFailure = ex + sc.listenerBus.post(event) } - val event = SparkListenerSQLExecutionEnd( - executionId, - System.currentTimeMillis(), - // Use empty string to indicate no error, as None may mean events generated by old - // versions of Spark. - errorMessage.orElse(Some(""))) - // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name` - // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We - // can specify the execution name in more places in the future, so that - // `QueryExecutionListener` can track more cases. - event.executionName = name - event.duration = endTime - startTime - event.qe = queryExecution - event.executionFailure = ex - sc.listenerBus.post(event) } } } finally { @@ -238,15 +240,30 @@ object SQLExecution extends Logging { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + withSessionTagsApplied(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } } } } + private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = { + sparkSession.sparkContext.addInternalJobTag(sparkSession.sessionJobTag) + val userTags = sparkSession.managedJobTags.values().asScala.toSeq + userTags.foreach(sparkSession.sparkContext.addJobTag) + + try { + block + } finally { + sparkSession.sparkContext.removeInternalJobTag(sparkSession.sessionJobTag) + userTags.foreach(sparkSession.sparkContext.removeJobTag) + } + } + /** * Wrap an action with specified SQL configs. These configs will be propagated to the executor * side via job local properties. @@ -286,10 +303,13 @@ object SQLExecution extends Logging { val originalSession = SparkSession.getActiveSession val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) - sc.setLocalProperties(localProps) - val res = body - // reset active session and local props. - sc.setLocalProperties(originalLocalProps) + val res = withSessionTagsApplied(activeSession) { + sc.setLocalProperties(localProps) + val res = body + // reset active session and local props. 
+ sc.setLocalProperties(originalLocalProps) + res + } if (originalSession.nonEmpty) { SparkSession.setActiveSession(originalSession.get) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 3ad0d73b9d97..dc16d960b2ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -170,17 +170,19 @@ class SparkSessionJobTaggingAndCancellationSuite // Tags are applied val threeJobs = sc.dagScheduler.activeJobs assert(threeJobs.size == 3) - for(ss <- Seq(sessionA, sessionB, sessionC)) { - val job = threeJobs.filter(_.properties.get("spark.sparkSession.uuid") == ss.sessionUUID) + for (ss <- Seq(sessionA, sessionB, sessionC)) { + val job = threeJobs.filter(_.properties.get(SparkContext.SPARK_JOB_TAGS) + .asInstanceOf[String].contains(ss.sessionUUID)) assert(job.size == 1) val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] .split(SparkContext.SPARK_JOB_TAGS_SEP) - assert(tags.forall(_.contains(s"spark-session-${ss.sessionUUID}-"))) + assert(tags.forall(_.contains(s"spark-session-${ss.sessionUUID}"))) val userTags = tags.map(_.replace(s"spark-session-${ss.sessionUUID}-", "")) + val sessionTag = s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}${ss.sessionJobTag}" ss match { - case s if s == sessionA => assert(userTags.toSet == Set("one")) - case s if s == sessionB => assert(userTags.toSet == Set("one", "two")) - case s if s == sessionC => assert(userTags.toSet == Set("boo")) + case s if s == sessionA => assert(userTags.toSet == Set(sessionTag, "one")) + case s if s == sessionB => assert(userTags.toSet == Set(sessionTag, "one", "two")) + case s if s == sessionC => assert(userTags.toSet == Set(sessionTag, "boo")) } } From 1d4d5cc30325acd422686f5d4860f489aa3207e8 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Thu, 29 Aug 2024 12:04:21 +0200 Subject: [PATCH 16/28] . --- .../scala/org/apache/spark/SparkContext.scala | 16 ++++---- .../apache/spark/scheduler/DAGScheduler.scala | 27 +++++++------- .../spark/scheduler/DAGSchedulerEvent.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 37 +++++++++++++------ .../spark/sql/execution/SQLExecution.scala | 7 +++- ...essionJobTaggingAndCancellationSuite.scala | 27 +++++++++----- 6 files changed, 73 insertions(+), 43 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a3b92a954e0e..233f6aa045c9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -978,7 +978,7 @@ class SparkContext(config: SparkConf) extends Logging { * @since 3.5.0 */ def clearJobTags(): Unit = { - val internalTags = getInternalJobTags() + val internalTags = getInternalJobTags() // exclude internal tags if (internalTags.isEmpty) { setLocalProperty(SparkContext.SPARK_JOB_TAGS, null) } else { @@ -2721,14 +2721,16 @@ class SparkContext(config: SparkConf) extends Logging { * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. * @param reason reason for cancellation - * @return A future that will be completed with the set of job IDs that were cancelled. + * @return A future that will be completed with the set of job tags that were cancelled. 
*/ - private[spark] def cancelJobsWithTagWithFuture(tag: String, reason: String): Future[Set[Int]] = { + private[spark] def cancelJobsWithTagWithFuture( + tag: String, + reason: String): Future[Seq[ActiveJob]] = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - val cancelledJobs = Promise[Set[Int]]() - dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(cancelledJobs)) + val cancelledJobs = Promise[Seq[ActiveJob]]() + dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs)) cancelledJobs.future } @@ -2743,7 +2745,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String, reason: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason), jobsToBeCancelled = None) + dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None) } /** @@ -2756,7 +2758,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, reason = None, jobsToBeCancelled = None) + dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None) } /** Cancel all jobs that have been scheduled or are running. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8aa5e41b1535..4200850845e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -145,7 +145,7 @@ private[spark] class DAGScheduler( private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) - private[spark] def numTotalJobs: Int = nextJobId.get() + private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] @@ -1120,15 +1120,15 @@ private[spark] class DAGScheduler( * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. * @param reason reason for cancellation. - * @param jobsToBeCancelled a promise to be completed with the set of job ids that are cancelled. + * @param cancelledJobs a promise to be completed with operation IDs being cancelled. */ def cancelJobsWithTag( tag: String, reason: Option[String], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason, jobsToBeCancelled)) + eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs)) } /** @@ -1245,19 +1245,20 @@ private[spark] class DAGScheduler( private[scheduler] def handleJobTagCancelled( tag: String, reason: Option[String], - jobsToBeCancelled: Option[Promise[Set[Int]]]): Unit = { - // Cancel all jobs belonging that have this tag. + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { + // Cancel all jobs that have all provided tags. // First finds all active jobs with this group id, and then kill stages for them. 
- val jobIds = activeJobs.filter { activeJob => + val jobsToBeCancelled = activeJobs.filter { activeJob => Option(activeJob.properties).exists { properties => Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("") .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } - }.map(_.jobId) - jobsToBeCancelled.foreach(_.success(jobIds.toSet)) + } + cancelledJobs.foreach(_.success(jobsToBeCancelled.toSeq)) - val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + val updatedReason = + reason.getOrElse("part of cancelled job tags %s".format(tag)) + jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason))) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -3126,8 +3127,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason, jobsToBeCancelled) => - dagScheduler.handleJobTagCancelled(tag, reason, jobsToBeCancelled) + case JobTagCancelled(tag, reason, cancelledJobs) => + dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs) case AllJobsCancelled => dagScheduler.doCancelAllJobs() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 0fd86ee12be4..8932d2ef323b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -74,7 +74,7 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, reason: Option[String], - jobsToBeCancelled: Option[Promise[Set[Int]]]) extends DAGSchedulerEvent + cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 0725955f6ca1..90c4bee6b371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -29,6 +29,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException, TaskContext} + import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.{Logging, MDC} @@ -880,11 +881,12 @@ class SparkSession private( * * @note This method will wait up to 60 seconds for the interruption request to be issued. - * @return Sequence of job IDs requested to be interrupted. + * @return Sequence of SQL execution IDs requested to be interrupted. * @since 4.0.0 */ - def interruptAll(): Seq[String] = interruptTag(sessionJobTag) + def interruptAll(): Seq[String] = + doInterruptTag(sessionJobTag, "as part of cancellation of all jobs", tagIsInternal = true) /** * Request to interrupt all currently running operations of this session with the given operation @@ -892,32 +894,43 @@ class SparkSession private( * * @note This method will wait up to 60 seconds for the interruption request to be issued. 
* - * @return Sequence of job IDs requested to be interrupted. + * @return Sequence of SQL execution IDs requested to be interrupted. */ def interruptTag(tag: String): Seq[String] = { val realTag = managedJobTags.get(tag) if (realTag == null) return Seq.empty + doInterruptTag(realTag, s"part of cancelled job tags $tag", tagIsInternal = false) + } + + private def doInterruptTag( + tag: String, + reason: String, + tagIsInternal: Boolean): Seq[String] = { + val realTag = if (tagIsInternal) s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag" else tag + val cancelledTags = + sparkContext.cancelJobsWithTagWithFuture(realTag, reason) - val cancelledIds = sparkContext.cancelJobsWithTagWithFuture(realTag, "Interrupted by user") - ThreadUtils.awaitResult(cancelledIds, 60.seconds).map(_.toString).toSeq + ThreadUtils.awaitResult(cancelledTags, 60.seconds) + .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY))) } /** - * Request to interrupt an operation of this session, given its job ID. + * Request to interrupt an operation of this session, given its SQL execution ID. * * @note This method will wait up to 60 seconds for the interruption request to be issued. * - * @return The job ID requested to be interrupted, as a single-element sequence, or an empty + * @return The execution ID requested to be interrupted, as a single-element sequence, or an empty * sequence if the operation is not started by this session. * * @since 4.0.0 */ - def interruptOperation(jobId: String): Seq[String] = { - scala.util.Try(jobId.toInt).toOption match { - case Some(jobIdToBeCancelled) => - Seq() + def interruptOperation(executionId: String): Seq[String] = { + scala.util.Try(executionId.toLong).toOption match { + case Some(executionIdToBeCancelled) => + val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled) + doInterruptTag(tagToBeCancelled, reason = "", tagIsInternal = true) case None => - throw new IllegalArgumentException("jobId must be a number in string form.") + throw new IllegalArgumentException("executionId must be a number in string form.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ab93133648c1..c3b1102c7373 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -44,7 +44,7 @@ object SQLExecution extends Logging { private def nextExecutionId: Long = _nextExecutionId.getAndIncrement - private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() + private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() def getQueryExecution(executionId: Long): QueryExecution = { executionIdToQueryExecution.get(executionId) @@ -52,6 +52,9 @@ object SQLExecution extends Logging { private val testing = sys.props.contains(IS_TESTING.key) + private[sql] def executionIdJobTag(session: SparkSession, id: Long) = + s"${session.sessionJobTag}-execution-root-id-$id" + private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext // only throw an exception during tests. a missing execution ID should not fail a job. @@ -82,6 +85,7 @@ object SQLExecution extends Logging { // And for the root execution, rootExecutionId == executionId. 
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString) + sc.addInternalJobTag(executionIdJobTag(sparkSession, executionId)) } val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong executionIdToQueryExecution.put(executionId, queryExecution) @@ -213,6 +217,7 @@ object SQLExecution extends Logging { // The current execution is the root execution if rootExecutionId == executionId. if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null) + sc.removeInternalJobTag(executionIdJobTag(sparkSession, executionId)) } sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index dc16d960b2ee..d455d7148271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkExcept import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.util.ExecutionListenerBus @@ -176,13 +177,21 @@ class SparkSessionJobTaggingAndCancellationSuite assert(job.size == 1) val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] .split(SparkContext.SPARK_JOB_TAGS_SEP) - assert(tags.forall(_.contains(s"spark-session-${ss.sessionUUID}"))) - val userTags = tags.map(_.replace(s"spark-session-${ss.sessionUUID}-", "")) + val sessionTag = s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}${ss.sessionJobTag}" + val executionRootIdTag = SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX + + SQLExecution.executionIdJobTag( + ss, + job.head.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) + val userTagsPrefix = s"spark-session-${ss.sessionUUID}-" + ss match { - case s if s == sessionA => assert(userTags.toSet == Set(sessionTag, "one")) - case s if s == sessionB => assert(userTags.toSet == Set(sessionTag, "one", "two")) - case s if s == sessionC => assert(userTags.toSet == Set(sessionTag, "boo")) + case s if s == sessionA => assert(tags.toSet == Set( + sessionTag, executionRootIdTag, s"${userTagsPrefix}one")) + case s if s == sessionB => assert(tags.toSet == Set( + sessionTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two")) + case s if s == sessionC => assert(tags.toSet == Set( + sessionTag, executionRootIdTag, s"${userTagsPrefix}boo")) } } @@ -190,19 +199,19 @@ class SparkSessionJobTaggingAndCancellationSuite assert(globalSession.interruptAll().isEmpty) assert(globalSession.interruptTag("one").isEmpty) assert(globalSession.interruptTag("two").isEmpty) - for (i <- 0 until globalSession.sparkContext.dagScheduler.numTotalJobs) { + for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) { assert(globalSession.interruptOperation(i.toString).isEmpty) } assert(jobEnded.intValue == 0) // One job cancelled - for (i <- 0 until 
globalSession.sparkContext.dagScheduler.numTotalJobs) { + for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) { sessionC.interruptOperation(i.toString) } val eC = intercept[SparkException] { ThreadUtils.awaitResult(jobC, 1.minute) }.getCause - assert(eC.getMessage contains "Interrupted") + assert(eC.getMessage contains "cancelled") assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) assert(jobEnded.intValue == 1) @@ -211,7 +220,7 @@ class SparkSessionJobTaggingAndCancellationSuite val eA = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 1.minute) }.getCause - assert(eA.getMessage contains "Interrupted") + assert(eA.getMessage contains "cancelled job tags one") assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) assert(jobEnded.intValue == 2) From d1208c48103997c55d46bbc6907f7902fba71eb6 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Thu, 29 Aug 2024 14:12:48 +0200 Subject: [PATCH 17/28] revert unnessesary changes and fix tests --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 ++-- .../org/apache/spark/sql/execution/SQLExecutionSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4200850845e4..5aafbdc2a08e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -168,7 +168,7 @@ private[spark] class DAGScheduler( // Stages that must be resubmitted due to fetch failures private[scheduler] val failedStages = new HashSet[Stage] - private[spark] val activeJobs = new HashSet[ActiveJob] + private[scheduler] val activeJobs = new HashSet[ActiveJob] // Job groups that are cancelled with `cancelFutureJobs` as true, with at most // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. 
On a new job submission, if its job group is in @@ -1254,7 +1254,7 @@ private[spark] class DAGScheduler( .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } } - cancelledJobs.foreach(_.success(jobsToBeCancelled.toSeq)) + cancelledJobs.map(_.success(jobsToBeCancelled.toSeq)) val updatedReason = reason.getOrElse("part of cancelled job tags %s".format(tag)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 94d33731b6de..059a4c9b8376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -228,7 +228,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper { spark.range(1).collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(jobTags.contains(jobTag)) + assert(jobTags.get.contains(jobTag)) assert(sqlJobTags.contains(jobTag)) } finally { spark.sparkContext.removeJobTag(jobTag) From 13342cf92693cf093744e4b11526972bcf9bdbf5 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Thu, 29 Aug 2024 14:34:23 +0200 Subject: [PATCH 18/28] comment --- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 +++-- .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 7 +++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 233f6aa045c9..d3a13b71879b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2720,8 +2720,9 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. * * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. - * @param reason reason for cancellation - * @return A future that will be completed with the set of job tags that were cancelled. + * @param reason reason for cancellation. + * @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and + * tags. */ private[spark] def cancelJobsWithTagWithFuture( tag: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 9e38b78ab334..c33d38e389ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -861,8 +861,8 @@ class SparkSession private( } /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. + * Get the operation tags that are currently set to be assigned to all the operations started by + * this session. * * @since 4.0.0 */ @@ -888,8 +888,7 @@ class SparkSession private( doInterruptTag(sessionJobTag, "as part of cancellation of all jobs", tagIsInternal = true) /** - * Request to interrupt all currently running operations of this session with the given operation - * tag. + * Request to interrupt all currently running operations of this session with the given job tag. * * @note This method will wait up to 60 seconds for the interruption request to be issued. 
* From 38799895d7788753087e9e984f88a70a7a5f1c12 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Thu, 29 Aug 2024 15:36:23 +0200 Subject: [PATCH 19/28] oh no --- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- ...essionJobTaggingAndCancellationSuite.scala | 553 +----------------- 2 files changed, 2 insertions(+), 553 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5aafbdc2a08e..8bd5a139dc85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -168,7 +168,7 @@ private[spark] class DAGScheduler( // Stages that must be resubmitted due to fetch failures private[scheduler] val failedStages = new HashSet[Stage] - private[scheduler] val activeJobs = new HashSet[ActiveJob] + private[spark] val activeJobs = new HashSet[ActiveJob] // Job groups that are cancelled with `cancelFutureJobs` as true, with at most // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index d455d7148271..dd3f3a1577c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -23,19 +23,12 @@ import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future} import scala.jdk.CollectionConverters._ -import org.apache.hadoop.fs.Path -import org.apache.logging.log4j.Level import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT -import org.apache.spark.internal.config.UI.UI_ENABLED +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf._ -import org.apache.spark.sql.util.ExecutionListenerBus import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.util.ThreadUtils @@ -236,548 +229,4 @@ class SparkSessionJobTaggingAndCancellationSuite fpool.shutdownNow() } } - - - test("SPARK-34087: Fix memory leak of ExecutionListenerBus") { - val spark = SparkSession.builder() - .master("local") - .getOrCreate() - - @inline def listenersNum(): Int = { - spark.sparkContext - .listenerBus - .listeners - .asScala - .count(_.isInstanceOf[ExecutionListenerBus]) - } - - (1 to 10).foreach { _ => - spark.cloneSession() - SparkSession.clearActiveSession() - } - - eventually(timeout(10.seconds), interval(1.seconds)) { - System.gc() - // After GC, the number of ExecutionListenerBus should be less than 11 (we created 11 - // SparkSessions in total). - // Since GC can't 100% guarantee all out-of-referenced objects be cleaned at one time, - // here, we check at least one listener is cleaned up to prove the mechanism works. 
- assert(listenersNum() < 11) - } - } - - test("create with config options and propagate them to SparkContext and SparkSession") { - val session = SparkSession.builder() - .master("local") - .config(UI_ENABLED.key, value = false) - .config("some-config", "v2") - .getOrCreate() - assert(session.sparkContext.conf.get("some-config") == "v2") - assert(session.conf.get("some-config") == "v2") - } - - test("use global default session") { - val session = SparkSession.builder().master("local").getOrCreate() - assert(SparkSession.builder().getOrCreate() == session) - } - - test("sets default and active session") { - assert(SparkSession.getDefaultSession == None) - assert(SparkSession.getActiveSession == None) - val session = SparkSession.builder().master("local").getOrCreate() - assert(SparkSession.getDefaultSession == Some(session)) - assert(SparkSession.getActiveSession == Some(session)) - } - - test("get active or default session") { - val session = SparkSession.builder().master("local").getOrCreate() - assert(SparkSession.active == session) - SparkSession.clearActiveSession() - assert(SparkSession.active == session) - SparkSession.clearDefaultSession() - intercept[SparkException](SparkSession.active) - session.stop() - } - - test("config options are propagated to existing SparkSession") { - val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() - assert(session1.conf.get("spark-config1") == "a") - val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() - assert(session1 == session2) - assert(session1.conf.get("spark-config1") == "b") - } - - test("use session from active thread session and propagate config options") { - val defaultSession = SparkSession.builder().master("local").getOrCreate() - val activeSession = defaultSession.newSession() - SparkSession.setActiveSession(activeSession) - val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() - - assert(activeSession != defaultSession) - assert(session == activeSession) - assert(session.conf.get("spark-config2") == "a") - assert(session.sessionState.conf == SQLConf.get) - assert(SQLConf.get.getConfString("spark-config2") == "a") - SparkSession.clearActiveSession() - - assert(SparkSession.builder().getOrCreate() == defaultSession) - } - - test("create a new session if the default session has been stopped") { - val defaultSession = SparkSession.builder().master("local").getOrCreate() - SparkSession.setDefaultSession(defaultSession) - defaultSession.stop() - val newSession = SparkSession.builder().master("local").getOrCreate() - assert(newSession != defaultSession) - } - - test("create a new session if the active thread session has been stopped") { - val activeSession = SparkSession.builder().master("local").getOrCreate() - SparkSession.setActiveSession(activeSession) - activeSession.stop() - val newSession = SparkSession.builder().master("local").getOrCreate() - assert(newSession != activeSession) - } - - test("create SparkContext first then SparkSession") { - val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") - val sparkContext2 = new SparkContext(conf) - val session = SparkSession.builder().config("key2", "value2").getOrCreate() - assert(session.conf.get("key1") == "value1") - assert(session.conf.get("key2") == "value2") - assert(session.sparkContext == sparkContext2) - // We won't update conf for existing `SparkContext` - assert(!sparkContext2.conf.contains("key2")) - assert(sparkContext2.conf.get("key1") == 
"value1") - } - - test("create SparkContext first then pass context to SparkSession") { - val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") - val newSC = new SparkContext(conf) - val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() - assert(session.conf.get("key1") == "value1") - assert(session.conf.get("key2") == "value2") - assert(session.sparkContext == newSC) - assert(session.sparkContext.conf.get("key1") == "value1") - // If the created sparkContext is passed through the Builder's API sparkContext, - // the conf of this sparkContext will not contain the conf set through the API config. - assert(!session.sparkContext.conf.contains("key2")) - assert(session.sparkContext.conf.get("spark.app.name") == "test") - } - - test("SPARK-15887: hive-site.xml should be loaded") { - val session = SparkSession.builder().master("local").getOrCreate() - assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") - assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") - } - - test("SPARK-15991: Set global Hadoop conf") { - val session = SparkSession.builder().master("local").getOrCreate() - val mySpecialKey = "my.special.key.15991" - val mySpecialValue = "msv" - try { - session.sparkContext.hadoopConfiguration.set(mySpecialKey, mySpecialValue) - assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) - } finally { - session.sparkContext.hadoopConfiguration.unset(mySpecialKey) - } - } - - test("SPARK-31234: RESET command will not change static sql configs and " + - "spark context conf values in SessionState") { - val session = SparkSession.builder() - .master("local") - .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31234") - .config("spark.app.name", "test-app-SPARK-31234") - .getOrCreate() - - assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") - session.sql("RESET") - assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") - } - - test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test-app-SPARK-31354-1") - val context = new SparkContext(conf) - SparkSession - .builder() - .sparkContext(context) - .master("local") - .getOrCreate() - val postFirstCreation = context.listenerBus.listeners.size() - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - - SparkSession - .builder() - .sparkContext(context) - .master("local") - .getOrCreate() - val postSecondCreation = context.listenerBus.listeners.size() - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - assert(postFirstCreation == postSecondCreation) - } - - test("SPARK-31532: should not propagate static sql configs to the existing" + - " active/default SparkSession") { - val session = SparkSession.builder() - .master("local") - .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31532") - .config("spark.app.name", "test-app-SPARK-31532") - .getOrCreate() - // do not propagate static sql configs to the existing active session - val session1 = SparkSession - .builder() - .config(GLOBAL_TEMP_DATABASE.key, "globalTempDB-SPARK-31532-1") - .getOrCreate() - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") - 
assert(session1.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") - - // do not propagate static sql configs to the existing default session - SparkSession.clearActiveSession() - val session2 = SparkSession - .builder() - .config(WAREHOUSE_PATH.key, "SPARK-31532-db") - .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31532-2") - .getOrCreate() - - assert(!session.conf.get(WAREHOUSE_PATH).contains("SPARK-31532-db")) - assert(session.conf.get(WAREHOUSE_PATH) === session2.conf.get(WAREHOUSE_PATH)) - assert(session2.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") - } - - test("SPARK-31532: propagate static sql configs if no existing SparkSession") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test-app-SPARK-31532-2") - .set(GLOBAL_TEMP_DATABASE.key, "globaltempdb-spark-31532") - .set(WAREHOUSE_PATH.key, "SPARK-31532-db") - SparkContext.getOrCreate(conf) - - // propagate static sql configs if no existing session - val session = SparkSession - .builder() - .config(GLOBAL_TEMP_DATABASE.key, "globalTempDB-SPARK-31532-2") - .config(WAREHOUSE_PATH.key, "SPARK-31532-db-2") - .getOrCreate() - assert(session.conf.get("spark.app.name") === "test-app-SPARK-31532-2") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532-2") - assert(session.conf.get(WAREHOUSE_PATH) contains "SPARK-31532-db-2") - } - - test("SPARK-32062: reset listenerRegistered in SparkSession") { - (1 to 2).foreach { i => - val conf = new SparkConf() - .setMaster("local") - .setAppName(s"test-SPARK-32062-$i") - val context = new SparkContext(conf) - val beforeListenerSize = context.listenerBus.listeners.size() - SparkSession - .builder() - .sparkContext(context) - .getOrCreate() - val afterListenerSize = context.listenerBus.listeners.size() - assert(beforeListenerSize + 1 == afterListenerSize) - context.stop() - } - } - - test("SPARK-32160: Disallow to create SparkSession in executors") { - val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate() - - val error = intercept[SparkException] { - session.range(1).foreach { v => - SparkSession.builder().master("local").getOrCreate() - () - } - }.getMessage() - - assert(error.contains("SparkSession should only be created and accessed on the driver.")) - } - - test("SPARK-32160: Allow to create SparkSession in executors if the config is set") { - val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate() - - session.range(1).foreach { v => - SparkSession.builder().master("local") - .config(EXECUTOR_ALLOW_SPARK_CONTEXT.key, true).getOrCreate().stop() - () - } - } - - test("SPARK-32991: Use conf in shared state as the original configuration for RESET") { - val wh = "spark.sql.warehouse.dir" - val td = "spark.sql.globalTempDatabase" - val custom = "spark.sql.custom" - - val conf = new SparkConf() - .setMaster("local") - .setAppName("SPARK-32991") - .set(wh, "./data1") - .set(td, "bob") - - val sc = new SparkContext(conf) - - val spark = SparkSession.builder() - .config(wh, "./data2") - .config(td, "alice") - .config(custom, "kyao") - .getOrCreate() - - // When creating the first session like above, we will update the shared spark conf to the - // newly specified values - val sharedWH = spark.sharedState.conf.get(wh) - val sharedTD = spark.sharedState.conf.get(td) - assert(sharedWH contains "data2", - "The warehouse dir in shared state should be determined by the 1st created spark session") - assert(sharedTD === "alice", - "Static sql configs in shared 
state should be determined by the 1st created spark session") - assert(spark.sharedState.conf.getOption(custom).isEmpty, - "Dynamic sql configs is session specific") - - assert(spark.conf.get(wh) contains sharedWH, - "The warehouse dir in session conf and shared state conf should be consistent") - assert(spark.conf.get(td) === sharedTD, - "Static sql configs in session conf and shared state conf should be consistent") - assert(spark.conf.get(custom) === "kyao", "Dynamic sql configs is session specific") - - spark.sql("RESET") - - assert(spark.conf.get(wh) contains sharedWH, - "The warehouse dir in shared state should be respect after RESET") - assert(spark.conf.get(td) === sharedTD, - "Static sql configs in shared state should be respect after RESET") - assert(spark.conf.get(custom) === "kyao", - "Dynamic sql configs in session initial map should be respect after RESET") - - val spark2 = SparkSession.builder() - .config(wh, "./data3") - .config(custom, "kyaoo").getOrCreate() - assert(spark2.conf.get(wh) contains sharedWH) - assert(spark2.conf.get(td) === sharedTD) - assert(spark2.conf.get(custom) === "kyaoo") - } - - test("SPARK-32991: RESET should work properly with multi threads") { - val wh = "spark.sql.warehouse.dir" - val td = "spark.sql.globalTempDatabase" - val custom = "spark.sql.custom" - val spark = ThreadUtils.runInNewThread("new session 0", false) { - SparkSession.builder() - .master("local") - .config(wh, "./data0") - .config(td, "bob") - .config(custom, "c0") - .getOrCreate() - } - - spark.sql(s"SET $custom=c1") - assert(spark.conf.get(custom) === "c1") - spark.sql("RESET") - assert(spark.conf.get(wh) contains "data0", - "The warehouse dir in shared state should be respect after RESET") - assert(spark.conf.get(td) === "bob", - "Static sql configs in shared state should be respect after RESET") - assert(spark.conf.get(custom) === "c0", - "Dynamic sql configs in shared state should be respect after RESET") - - val spark1 = ThreadUtils.runInNewThread("new session 1", false) { - SparkSession.builder().getOrCreate() - } - - assert(spark === spark1) - - // TODO: SPARK-33718: After clear sessions, the SharedState will be unreachable, then all - // the new static will take effect. 
- SparkSession.clearDefaultSession() - val spark2 = ThreadUtils.runInNewThread("new session 2", false) { - SparkSession.builder() - .master("local") - .config(wh, "./data1") - .config(td, "alice") - .config(custom, "c2") - .getOrCreate() - } - - assert(spark2 !== spark) - spark2.sql(s"SET $custom=c1") - assert(spark2.conf.get(custom) === "c1") - spark2.sql("RESET") - assert(spark2.conf.get(wh) contains "data1") - assert(spark2.conf.get(td) === "alice") - assert(spark2.conf.get(custom) === "c2") - - } - - test("SPARK-33944: warning setting hive.metastore.warehouse.dir using session options") { - val msg = "Not allowing to set hive.metastore.warehouse.dir in SparkSession's options" - val logAppender = new LogAppender(msg) - withLogAppender(logAppender) { - SparkSession.builder() - .master("local") - .config("hive.metastore.warehouse.dir", "any") - .getOrCreate() - .sharedState - } - assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) - } - - test("SPARK-33944: no warning setting spark.sql.warehouse.dir using session options") { - val msg = "Not allowing to set hive.metastore.warehouse.dir in SparkSession's options" - val logAppender = new LogAppender(msg) - withLogAppender(logAppender) { - SparkSession.builder() - .master("local") - .config("spark.sql.warehouse.dir", "any") - .getOrCreate() - .sharedState - } - assert(!logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) - } - - Seq(".", "..", "dir0", "dir0/dir1", "/dir0/dir1", "./dir0").foreach { pathStr => - test(s"SPARK-34558: warehouse path ($pathStr) should be qualified for spark/hadoop conf") { - val path = new Path(pathStr) - val conf = new SparkConf().set(WAREHOUSE_PATH, pathStr) - val session = SparkSession.builder() - .master("local") - .config(conf) - .getOrCreate() - val hadoopConf = session.sessionState.newHadoopConf() - val expected = path.getFileSystem(hadoopConf).makeQualified(path).toString - // session related configs - assert(hadoopConf.get("hive.metastore.warehouse.dir") === expected) - assert(session.conf.get(WAREHOUSE_PATH) === expected) - assert(session.sessionState.conf.warehousePath === expected) - - // shared configs - assert(session.sharedState.conf.get(WAREHOUSE_PATH) === expected) - assert(session.sharedState.hadoopConf.get("hive.metastore.warehouse.dir") === expected) - - // spark context configs - assert(session.sparkContext.conf.get(WAREHOUSE_PATH) === expected) - assert(session.sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") === - expected) - } - } - - test("SPARK-34558: Create a working SparkSession with a broken FileSystem") { - val msg = "Cannot qualify the warehouse path, leaving it unqualified" - val logAppender = new LogAppender(msg) - withLogAppender(logAppender) { - val session = - SparkSession.builder() - .master("local") - .config(WAREHOUSE_PATH.key, "unknown:///mydir") - .getOrCreate() - session.sql("SELECT 1").collect() - } - assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) - } - - test("SPARK-37727: Show ignored configurations in debug level logs") { - // Create one existing SparkSession to check following logs. 
- SparkSession.builder().master("local").getOrCreate() - - val logAppender = new LogAppender - logAppender.setThreshold(Level.DEBUG) - withLogAppender(logAppender, level = Some(Level.DEBUG)) { - SparkSession.builder() - .config("spark.sql.warehouse.dir", "2") - .config("spark.abc", "abcb") - .config("spark.abcd", "abcb4") - .getOrCreate() - } - - val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) - Seq( - "Ignored static SQL configurations", - "spark.sql.warehouse.dir=2", - "Configurations that might not take effect", - "spark.abcd=abcb4", - "spark.abc=abcb").foreach { msg => - assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") - } - } - - test("SPARK-37727: Hide the same configuration already explicitly set in logs") { - // Create one existing SparkSession to check following logs. - SparkSession.builder().master("local").config("spark.abc", "abc").getOrCreate() - - val logAppender = new LogAppender - logAppender.setThreshold(Level.DEBUG) - withLogAppender(logAppender, level = Some(Level.DEBUG)) { - // Ignore logs because it's already set. - SparkSession.builder().config("spark.abc", "abc").getOrCreate() - // Show logs for only configuration newly set. - SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() - // Ignore logs because it's set ^. - SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() - } - - val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) - Seq( - "Using an existing Spark session; only runtime SQL configurations will take effect", - "Configurations that might not take effect", - "spark.abc.new=abc").foreach { msg => - assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") - } - - assert( - !logs.exists(_.contains("spark.abc=abc")), - s"'spark.abc=abc' existed in:\n${logs.mkString("\n")}") - } - - test("SPARK-37727: Hide runtime SQL configurations in logs") { - // Create one existing SparkSession to check following logs. 
- SparkSession.builder().master("local").getOrCreate() - - val logAppender = new LogAppender - logAppender.setThreshold(Level.DEBUG) - withLogAppender(logAppender, level = Some(Level.DEBUG)) { - // Ignore logs for runtime SQL configurations - SparkSession.builder().config("spark.sql.ansi.enabled", "true").getOrCreate() - // Show logs for Spark core configuration - SparkSession.builder().config("spark.buffer.size", "1234").getOrCreate() - // Show logs for custom runtime options - SparkSession.builder().config("spark.sql.source.abc", "abc").getOrCreate() - // Show logs for static SQL configurations - SparkSession.builder().config("spark.sql.warehouse.dir", "xyz").getOrCreate() - } - - val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) - Seq( - "spark.buffer.size=1234", - "spark.sql.source.abc=abc", - "spark.sql.warehouse.dir=xyz").foreach { msg => - assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") - } - - assert( - !logs.exists(_.contains("spark.sql.ansi.enabled\"")), - s"'spark.sql.ansi.enabled' existed in:\n${logs.mkString("\n")}") - } - - test("SPARK-40163: SparkSession.config(Map)") { - val map: Map[String, Any] = Map( - "string" -> "", - "boolean" -> true, - "double" -> 0.0, - "long" -> 0L - ) - - val session = SparkSession.builder() - .master("local") - .config(map) - .getOrCreate() - - for (e <- map) { - assert(session.conf.get(e._1) == e._2.toString) - } - } } From cf6437f396c71dbcad1e46fd38b45644f0cd0c62 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Fri, 30 Aug 2024 10:34:41 +0200 Subject: [PATCH 20/28] remove internal tags --- .../scala/org/apache/spark/SparkContext.scala | 42 ++----------------- .../org/apache/spark/sql/SparkSession.scala | 14 +++---- .../spark/sql/execution/SQLExecution.scala | 10 ++--- ...essionJobTaggingAndCancellationSuite.scala | 10 ++--- 4 files changed, 18 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d3a13b71879b..e664bd8bbc94 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -912,18 +912,11 @@ class SparkContext(config: SparkConf) extends Logging { */ def addJobTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - val existingTags = getJobTags() ++ getInternalJobTags() + val existingTags = getJobTags() val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } - /** - * Add a tag to be assigned to all the jobs started by this thread. The tag will be prefixed with - * an internal prefix to avoid conflicts with user tags. - */ - private[spark] def addInternalJobTag(tag: String): Unit = - addJobTag(s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag") - /** * Remove a tag previously added to be assigned to all the jobs started by this thread. * Noop if such a tag was not added earlier. 
@@ -934,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def removeJobTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - val existingTags = getJobTags() ++ getInternalJobTags() + val existingTags = getJobTags() val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) if (newTags.isEmpty) { clearJobTags() @@ -943,22 +936,6 @@ class SparkContext(config: SparkConf) extends Logging { } } - /** - * Remove an internal tag previously added to be assigned to all the jobs started by this thread. - */ - private[spark] def removeInternalJobTag(tag: String): Unit = - removeJobTag(s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag") - - /** - * Get the tags that are currently set to be assigned to all the jobs started by this thread. - */ - private[spark] def getInternalJobTags(): Set[String] = { - Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS)) - .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet) - .getOrElse(Set()) - .filter(_.startsWith(SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX)) // only internal tags - } - /** * Get the tags that are currently set to be assigned to all the jobs started by this thread. * @@ -968,8 +945,7 @@ class SparkContext(config: SparkConf) extends Logging { Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS)) .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet) .getOrElse(Set()) - .filterNot(_.startsWith(SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX)) // exclude internal tags - .filter(_.nonEmpty) // empty string tag should not happen, but be defensive + .filter(!_.isEmpty) // empty string tag should not happen, but be defensive } /** @@ -978,14 +954,7 @@ class SparkContext(config: SparkConf) extends Logging { * @since 3.5.0 */ def clearJobTags(): Unit = { - val internalTags = getInternalJobTags() // exclude internal tags - if (internalTags.isEmpty) { - setLocalProperty(SparkContext.SPARK_JOB_TAGS, null) - } else { - setLocalProperty( - SparkContext.SPARK_JOB_TAGS, - internalTags.mkString(SparkContext.SPARK_JOB_TAGS_SEP)) - } + setLocalProperty(SparkContext.SPARK_JOB_TAGS, null) } /** @@ -3148,9 +3117,6 @@ object SparkContext extends Logging { /** Separator of tags in SPARK_JOB_TAGS property */ private[spark] val SPARK_JOB_TAGS_SEP = "," - /** Prefix to mark a tag to be visible internally, not by users */ - private[spark] val SPARK_JOB_TAGS_INTERNAL_PREFIX = "~~spark~internal~tag~~" - // Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag private[spark] def throwIfInvalidTag(tag: String) = { if (tag == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c33d38e389ca..41bd6d4f7c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -885,7 +885,7 @@ class SparkSession private( * @since 4.0.0 */ def interruptAll(): Seq[String] = - doInterruptTag(sessionJobTag, "as part of cancellation of all jobs", tagIsInternal = true) + doInterruptTag(sessionJobTag, "as part of cancellation of all jobs") /** * Request to interrupt all currently running operations of this session with the given job tag. 
@@ -897,16 +897,12 @@ class SparkSession private( def interruptTag(tag: String): Seq[String] = { val realTag = managedJobTags.get(tag) if (realTag == null) return Seq.empty - doInterruptTag(realTag, s"part of cancelled job tags $tag", tagIsInternal = false) + doInterruptTag(realTag, s"part of cancelled job tags $tag") } - private def doInterruptTag( - tag: String, - reason: String, - tagIsInternal: Boolean): Seq[String] = { - val realTag = if (tagIsInternal) s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}$tag" else tag + private def doInterruptTag(tag: String, reason: String): Seq[String] = { val cancelledTags = - sparkContext.cancelJobsWithTagWithFuture(realTag, reason) + sparkContext.cancelJobsWithTagWithFuture(tag, reason) ThreadUtils.awaitResult(cancelledTags, 60.seconds) .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY))) @@ -926,7 +922,7 @@ class SparkSession private( scala.util.Try(executionId.toLong).toOption match { case Some(executionIdToBeCancelled) => val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled) - doInterruptTag(tagToBeCancelled, reason = "", tagIsInternal = true) + doInterruptTag(tagToBeCancelled, reason = "") case None => throw new IllegalArgumentException("executionId must be a number in string form.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index c3b1102c7373..2c2104c088cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -85,7 +85,7 @@ object SQLExecution extends Logging { // And for the root execution, rootExecutionId == executionId. if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString) - sc.addInternalJobTag(executionIdJobTag(sparkSession, executionId)) + sc.addJobTag(executionIdJobTag(sparkSession, executionId)) } val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong executionIdToQueryExecution.put(executionId, queryExecution) @@ -133,7 +133,7 @@ object SQLExecution extends Logging { sparkPlanInfo = SparkPlanInfo.EMPTY, time = System.currentTimeMillis(), modifiedConfigs = redactedConfigs, - jobTags = sc.getJobTags() ++ sc.getInternalJobTags(), + jobTags = sc.getJobTags(), jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) ) try { @@ -217,7 +217,7 @@ object SQLExecution extends Logging { // The current execution is the root execution if rootExecutionId == executionId. 
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null) - sc.removeInternalJobTag(executionIdJobTag(sparkSession, executionId)) + sc.removeJobTag(executionIdJobTag(sparkSession, executionId)) } sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel) } @@ -257,14 +257,14 @@ object SQLExecution extends Logging { } private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = { - sparkSession.sparkContext.addInternalJobTag(sparkSession.sessionJobTag) + sparkSession.sparkContext.addJobTag(sparkSession.sessionJobTag) val userTags = sparkSession.managedJobTags.values().asScala.toSeq userTags.foreach(sparkSession.sparkContext.addJobTag) try { block } finally { - sparkSession.sparkContext.removeInternalJobTag(sparkSession.sessionJobTag) + sparkSession.sparkContext.removeJobTag(sparkSession.sessionJobTag) userTags.foreach(sparkSession.sparkContext.removeJobTag) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index dd3f3a1577c4..e528bdd7392d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -171,20 +171,18 @@ class SparkSessionJobTaggingAndCancellationSuite val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] .split(SparkContext.SPARK_JOB_TAGS_SEP) - val sessionTag = s"${SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX}${ss.sessionJobTag}" - val executionRootIdTag = SparkContext.SPARK_JOB_TAGS_INTERNAL_PREFIX + - SQLExecution.executionIdJobTag( + val executionRootIdTag = SQLExecution.executionIdJobTag( ss, job.head.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) val userTagsPrefix = s"spark-session-${ss.sessionUUID}-" ss match { case s if s == sessionA => assert(tags.toSet == Set( - sessionTag, executionRootIdTag, s"${userTagsPrefix}one")) + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one")) case s if s == sessionB => assert(tags.toSet == Set( - sessionTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two")) + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two")) case s if s == sessionC => assert(tags.toSet == Set( - sessionTag, executionRootIdTag, s"${userTagsPrefix}boo")) + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo")) } } From b3b7cbc1ad257c3b885956b04f521dacfa0cb878 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Fri, 30 Aug 2024 12:07:14 +0200 Subject: [PATCH 21/28] test --- ...essionJobTaggingAndCancellationSuite.scala | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index e528bdd7392d..012f98bfabaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -81,6 +81,35 @@ class SparkSessionJobTaggingAndCancellationSuite assert(session.getTags() == Set("one")) } + test("Tags set from session are prefixed with session UUID") { + sc = new 
SparkContext("local[2]", "test") + val session = SparkSession.builder().sparkContext(sc).getOrCreate() + import session.implicits._ + + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sem.release() + } + }) + + session.addTag("one") + Future { + session.range(1, 10000).map { i => Thread.sleep(100); i }.count() + }(ExecutionContext.global) + + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + val activeJobsFuture = + session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason") + val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head + val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS) + .split(SparkContext.SPARK_JOB_TAGS_SEP) + assert(actualTags.toSet == Set( + session.sessionJobTag, + s"${session.sessionJobTag}-one", + SQLExecution.executionIdJobTag(session, 0L))) + } + test("Cancellation APIs in SparkSession are isolated") { sc = new SparkContext("local[2]", "test") val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate() @@ -172,8 +201,8 @@ class SparkSessionJobTaggingAndCancellationSuite .split(SparkContext.SPARK_JOB_TAGS_SEP) val executionRootIdTag = SQLExecution.executionIdJobTag( - ss, - job.head.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) + ss, + job.head.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) val userTagsPrefix = s"spark-session-${ss.sessionUUID}-" ss match { From 7338b1da7ae87c855ba65bb23ffcb0186796024e Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Fri, 30 Aug 2024 14:34:09 +0200 Subject: [PATCH 22/28] move doc to api --- .../org/apache/spark/sql/SparkSession.scala | 70 +++---------------- .../apache/spark/sql/api/SparkSession.scala | 62 ++++++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 64 ++++------------- 3 files changed, 85 insertions(+), 111 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index a4aacf2f98f8..9d7ecbc705d6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -440,7 +440,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptAll(): Seq[String] = { + override def interruptAll(): Seq[String] = { client.interruptAll().getInterruptedIdsList.asScala.toSeq } @@ -453,7 +453,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq } @@ -466,7 +466,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptOperation(operationId: String): Seq[String] = { + override def interruptOperation(operationId: String): Seq[String] = { client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } @@ -497,65 +497,17 @@ class SparkSession private[sql] ( SparkSession.onSessionClose(this) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. 
The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def addTag(tag: String): Unit = { - client.addTag(tag) - } + /** @inheritdoc */ + override def addTag(tag: String): Unit = client.addTag(tag) - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def removeTag(tag: String): Unit = { - client.removeTag(tag) - } + /** @inheritdoc */ + override def removeTag(tag: String): Unit = client.removeTag(tag) - /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. - * - * @since 3.5.0 - */ - def getTags(): Set[String] = { - client.getTags() - } + /** @inheritdoc */ + override def getTags(): Set[String] = client.getTags() - /** - * Clear the current thread's operation tags. - * - * @since 3.5.0 - */ - def clearTags(): Unit = { - client.clearTags() - } + /** @inheritdoc */ + override def clearTags(): Unit = client.clearTags() /** * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index d156aba934b6..06a6741461ac 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -367,6 +367,68 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C @scala.annotation.varargs def addArtifacts(uri: URI*): Unit + + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * Often, a unit of execution in an application consists of multiple Spark executions. + * Application programmers can use this method to group all those jobs together and give a group + * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all + * running executions with this tag. For example: + * {{{ + * // In the main thread: + * spark.addTag("myjobs") + * spark.range(10).map(i => { Thread.sleep(10); i }).collect() + * + * // In a separate thread: + * spark.interruptTag("myjobs") + * }}} + * + * There may be multiple tags present at the same time, so different parts of application may + * use different tags to perform cancellation at different levels of granularity. + * + * @param tag + * The tag to be added. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def addTag(tag: String): Unit + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. 
Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def removeTag(tag: String): Unit + + /** + * Get the operation tags that are currently set to be assigned to all the operations started by + * this thread in this session. + * + * @since 4.0.0 + */ + def getTags(): Set[String] + + /** + * Clear the current thread's operation tags. + * + * @since 4.0.0 + */ + def clearTags(): Unit + + // No docstring, the meaning of return value depends on the implementation. + def interruptAll(): Seq[String] + + // No docstring, the meaning of return value depends on the implementation. + def interruptTag(tag: String): Seq[String] + + // No docstring, the meaning of return value depends on the implementation. + def interruptOperation(operationId: String): Seq[String] + /** * Executes some code block and prints to stdout the time taken to execute the block. This is * available in Scala only and is used primarily for interactive testing and debugging. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 23d7b22972f3..1a2bed65b3d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -671,63 +671,23 @@ class SparkSession private( artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts)) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 4.0.0 - */ - def addTag(tag: String): Unit = { + /** @inheritdoc */ + override def addTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag") } - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 4.0.0 - */ - def removeTag(tag: String): Unit = { + /** @inheritdoc */ + override def removeTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) managedJobTags.remove(tag) } - /** - * Get the operation tags that are currently set to be assigned to all the operations started by - * this session. - * - * @since 4.0.0 - */ - def getTags(): Set[String] = managedJobTags.keys().asScala.toSet + /** @inheritdoc */ + override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet - /** - * Clear the current thread's operation tags. 
- * - * @since 4.0.0 - */ - def clearTags(): Unit = managedJobTags.clear() + /** @inheritdoc */ + override def clearTags(): Unit = managedJobTags.clear() /** * Request to interrupt all currently running operations of this session. @@ -738,7 +698,7 @@ class SparkSession private( * @since 4.0.0 */ - def interruptAll(): Seq[String] = + override def interruptAll(): Seq[String] = doInterruptTag(sessionJobTag, "as part of cancellation of all jobs") /** @@ -748,7 +708,7 @@ class SparkSession private( * * @return Sequence of SQL execution IDs requested to be interrupted. */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { val realTag = managedJobTags.get(tag) if (realTag == null) return Seq.empty doInterruptTag(realTag, s"part of cancelled job tags $tag") @@ -772,8 +732,8 @@ class SparkSession private( * * @since 4.0.0 */ - def interruptOperation(executionId: String): Seq[String] = { - scala.util.Try(executionId.toLong).toOption match { + override def interruptOperation(operationId: String): Seq[String] = { + scala.util.Try(operationId.toLong).toOption match { case Some(executionIdToBeCancelled) => val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled) doInterruptTag(tagToBeCancelled, reason = "") From 905bf91610454313fff5a79530f0295bc78deb49 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 3 Sep 2024 14:27:42 +0200 Subject: [PATCH 23/28] fix test --- .../sql/SparkSessionJobTaggingAndCancellationSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 012f98bfabaf..4a5e5e426e68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -107,7 +107,9 @@ class SparkSessionJobTaggingAndCancellationSuite assert(actualTags.toSet == Set( session.sessionJobTag, s"${session.sessionJobTag}-one", - SQLExecution.executionIdJobTag(session, 0L))) + SQLExecution.executionIdJobTag( + session, + activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong))) } test("Cancellation APIs in SparkSession are isolated") { From 514b5e4749643189de75a9083286376654a823b0 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 10 Sep 2024 15:53:38 +0200 Subject: [PATCH 24/28] address mridulm's comments --- .../apache/spark/scheduler/DAGScheduler.scala | 5 ++--- ...kSessionJobTaggingAndCancellationSuite.scala | 17 +++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8bd5a139dc85..2c89fe7885d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -168,7 +168,7 @@ private[spark] class DAGScheduler( // Stages that must be resubmitted due to fetch failures private[scheduler] val failedStages = new HashSet[Stage] - private[spark] val activeJobs = new HashSet[ActiveJob] + private[scheduler] val activeJobs = new HashSet[ActiveJob] // Job groups that are cancelled with `cancelFutureJobs` as true, with at most // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. 
On a new job submission, if its job group is in @@ -1254,11 +1254,10 @@ private[spark] class DAGScheduler( .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } } - cancelledJobs.map(_.success(jobsToBeCancelled.toSeq)) - val updatedReason = reason.getOrElse("part of cancelled job tags %s".format(tag)) jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason))) + cancelledJobs.map(_.success(jobsToBeCancelled.toSeq)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index 4a5e5e426e68..dda729b852c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future} @@ -25,8 +25,8 @@ import scala.jdk.CollectionConverters._ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ - import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} + import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.tags.ExtendedSQLTest @@ -128,9 +128,11 @@ class SparkSessionJobTaggingAndCancellationSuite // Add a listener to release the semaphore once jobs are launched. 
val sem = new Semaphore(0) val jobEnded = new AtomicInteger(0) + val jobProperties: ConcurrentHashMap[Int, java.util.Properties] = new ConcurrentHashMap() sc.addSparkListener(new SparkListener { override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobProperties.put(jobStart.jobId, jobStart.properties) sem.release() } @@ -193,18 +195,17 @@ class SparkSessionJobTaggingAndCancellationSuite assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES)) // Tags are applied - val threeJobs = sc.dagScheduler.activeJobs - assert(threeJobs.size == 3) + assert(jobProperties.size == 3) for (ss <- Seq(sessionA, sessionB, sessionC)) { - val job = threeJobs.filter(_.properties.get(SparkContext.SPARK_JOB_TAGS) + val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS) .asInstanceOf[String].contains(ss.sessionUUID)) - assert(job.size == 1) - val tags = job.head.properties.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] + assert(jobProperty.size == 1) + val tags = jobProperty.head.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] .split(SparkContext.SPARK_JOB_TAGS_SEP) val executionRootIdTag = SQLExecution.executionIdJobTag( ss, - job.head.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) + jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) val userTagsPrefix = s"spark-session-${ss.sessionUUID}-" ss match { From c6fb41f93e8b27c611f249444d150889c65bd325 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 10 Sep 2024 16:38:41 +0200 Subject: [PATCH 25/28] address herman's comments --- .../scala/org/apache/spark/SparkContext.scala | 32 +++++++++++++++---- .../apache/spark/sql/api/SparkSession.scala | 31 ++++++++++++++++-- .../org/apache/spark/sql/SparkSession.scala | 12 +++---- .../spark/sql/execution/SQLExecution.scala | 8 ++--- 4 files changed, 63 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e664bd8bbc94..042179d86c31 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -910,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - def addJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def addJobTag(tag: String): Unit = addJobTags(Set(tag)) + + /** + * Add multiple tags to be assigned to all the jobs started by this thread. + * See [[addJobTag]] for more details. + * + * @param tags The tags to be added. Cannot contain ',' (comma) character. + * + * @since 4.0.0 + */ + def addJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } @@ -925,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - def removeJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def removeJobTag(tag: String): Unit = removeJobTags(Set(tag)) + + /** + * Remove multiple tags to be assigned to all the jobs started by this thread. + * See [[removeJobTag]] for more details. + * + * @param tags The tags to be removed. Cannot contain ',' (comma) character. 
+ * + * @since 4.0.0 + */ + def removeJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) if (newTags.isEmpty) { clearJobTags() } else { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 06a6741461ac..a230f3abf647 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -420,13 +420,38 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C */ def clearTags(): Unit - // No docstring, the meaning of return value depends on the implementation. + /** + * Request to interrupt all currently running operations of this session. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + + * @return Sequence of operation IDs requested to be interrupted. + + * @since 4.0.0 + */ def interruptAll(): Seq[String] - // No docstring, the meaning of return value depends on the implementation. + /** + * Request to interrupt all currently running operations of this session with the given job tag. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return Sequence of operation IDs requested to be interrupted. + + * @since 4.0.0 + */ def interruptTag(tag: String): Seq[String] - // No docstring, the meaning of return value depends on the implementation. + /** + * Request to interrupt an operation of this session, given its operation ID. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return The operation ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. + * + * @since 4.0.0 + */ def interruptOperation(operationId: String): Seq[String] /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1a2bed65b3d6..441836cffa64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -138,8 +138,9 @@ class SparkSession private( * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. */ @transient - private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = - new ConcurrentHashMap(parentManagedJobTags.asJava) + private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = { + new ConcurrentHashMap(parentManagedJobTags.asJava) + } /** @inheritdoc */ def version: String = SPARK_VERSION @@ -678,10 +679,7 @@ class SparkSession private( } /** @inheritdoc */ - override def removeTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) - managedJobTags.remove(tag) - } + override def removeTag(tag: String): Unit = managedJobTags.remove(tag) /** @inheritdoc */ override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet @@ -707,6 +705,8 @@ class SparkSession private( * @note This method will wait up to 60 seconds for the interruption request to be issued. * * @return Sequence of SQL execution IDs requested to be interrupted. 
+ + * @since 4.0.0 */ override def interruptTag(tag: String): Seq[String] = { val realTag = managedJobTags.get(tag) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c2104c088cb..3a406f4c0d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -257,15 +257,13 @@ object SQLExecution extends Logging { } private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = { - sparkSession.sparkContext.addJobTag(sparkSession.sessionJobTag) - val userTags = sparkSession.managedJobTags.values().asScala.toSeq - userTags.foreach(sparkSession.sparkContext.addJobTag) + val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag + sparkSession.sparkContext.addJobTags(allTags) try { block } finally { - sparkSession.sparkContext.removeJobTag(sparkSession.sessionJobTag) - userTags.foreach(sparkSession.sparkContext.removeJobTag) + sparkSession.sparkContext.removeJobTags(allTags) } } From 2a0292c1de90b2f966d792c88155f9d8b97b43c3 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 10 Sep 2024 16:43:52 +0200 Subject: [PATCH 26/28] address hyukjin's comment --- .../scala/org/apache/spark/sql/SparkSession.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 441836cffa64..0cd5d86f3846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -688,7 +688,9 @@ class SparkSession private( override def clearTags(): Unit = managedJobTags.clear() /** - * Request to interrupt all currently running operations of this session. + * Request to interrupt all currently running SQL operations of this session. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. * * @note This method will wait up to 60 seconds for the interruption request to be issued. @@ -700,7 +702,10 @@ class SparkSession private( doInterruptTag(sessionJobTag, "as part of cancellation of all jobs") /** - * Request to interrupt all currently running operations of this session with the given job tag. + * Request to interrupt all currently running SQL operations of this session with the given + * job tag. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. * * @note This method will wait up to 60 seconds for the interruption request to be issued. * @@ -723,7 +728,9 @@ class SparkSession private( } /** - * Request to interrupt an operation of this session, given its SQL execution ID. + * Request to interrupt a SQL operation of this session, given its SQL execution ID. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. * * @note This method will wait up to 60 seconds for the interruption request to be issued. 
* From a55c47c26a587434738f43eeb2df8fc4cd744bf6 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Mon, 16 Sep 2024 20:01:18 +0200 Subject: [PATCH 27/28] scalastyle --- .../spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala index dda729b852c1..e9fd07ecf18b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -25,8 +25,8 @@ import scala.jdk.CollectionConverters._ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.tags.ExtendedSQLTest From e66ba0aeca446a3b0ff903e8f05f2505e08eb44d Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Tue, 17 Sep 2024 15:17:36 +0200 Subject: [PATCH 28/28] fmt --- .../apache/spark/sql/api/SparkSession.scala | 27 +++++++++++-------- .../org/apache/spark/sql/functions.scala | 4 +-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 898beffe37b0..4767a5e1dfd2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -390,7 +390,6 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C @scala.annotation.varargs def addArtifacts(uri: URI*): Unit - /** * Add a tag to be assigned to all the operations started by this thread in this session. * @@ -446,10 +445,12 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C /** * Request to interrupt all currently running operations of this session. * - * @note This method will wait up to 60 seconds for the interruption request to be issued. - - * @return Sequence of operation IDs requested to be interrupted. - + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. + * * @since 4.0.0 */ def interruptAll(): Seq[String] @@ -457,10 +458,12 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C /** * Request to interrupt all currently running operations of this session with the given job tag. * - * @note This method will wait up to 60 seconds for the interruption request to be issued. + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. * - * @return Sequence of operation IDs requested to be interrupted. - * @since 4.0.0 */ def interruptTag(tag: String): Seq[String] @@ -468,10 +471,12 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C /** * Request to interrupt an operation of this session, given its operation ID. 
* - * @note This method will wait up to 60 seconds for the interruption request to be issued. + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. * - * @return The operation ID requested to be interrupted, as a single-element sequence, or an empty - * sequence if the operation is not started by this session. + * @return + * The operation ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. * * @since 4.0.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index a8b2044ba8a4..1ee86ae1a113 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -6853,8 +6853,8 @@ object functions { /** * Converts a column containing nested inputs (array/map/struct) into a variants where maps and - * structs are converted to variant objects which are unordered unlike SQL structs. Input maps can - * only have string keys. + * structs are converted to variant objects which are unordered unlike SQL structs. Input maps + * can only have string keys. * * @param col * a column with a nested schema or column name.
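Note on the withSessionTagsApplied change above: it reduces the tag bookkeeping to a single bracket around the executed block — collect the session tag together with all user-managed tag values, attach them before the block runs, and always detach them in a finally. Below is only an illustrative, Spark-free sketch of that pattern; TagScope, runTagged, TagScopeDemo and the example tag strings are invented names for illustration and are not part of this patch series.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object TagScope {
  // Stands in for the per-context job-tag bookkeeping; a concurrent set so
  // several threads can attach and detach tags safely.
  private val activeTags = ConcurrentHashMap.newKeySet[String]()

  def addTags(tags: Set[String]): Unit = tags.foreach(t => activeTags.add(t))
  def removeTags(tags: Set[String]): Unit = tags.foreach(t => activeTags.remove(t))
  def currentTags: Set[String] = activeTags.asScala.toSet

  // The bracket pattern: session tag plus all user tags are attached for the
  // duration of the block and always detached afterwards, even on failure.
  def runTagged[T](sessionTag: String, userTags: Set[String])(block: => T): T = {
    val allTags = userTags + sessionTag
    addTags(allTags)
    try block
    finally removeTags(allTags)
  }
}

object TagScopeDemo {
  def main(args: Array[String]): Unit = {
    val result = TagScope.runTagged("spark-session-1234-report", Set("report")) {
      // Both tags are visible while the block runs, so a cancellation keyed on
      // either tag could target this work (cf. cancelJobsWithTag).
      println(s"tags while running: ${TagScope.currentTags}")
      42
    }
    println(s"result: $result, tags after: ${TagScope.currentTags}")
  }
}

Running TagScopeDemo prints both tags while the block executes and an empty set afterwards. Detaching in a finally mirrors the design choice in the patch, where removeJobTags is likewise placed in the finally so tags never leak past a failed or interrupted execution.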