45 changes: 31 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,32 @@ class SparkSession private(
/** Tag to mark all jobs owned by this session. */
private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"

/**
* A UUID that is unique per thread. Used by managedJobTags to make sure that the same
* user tag from different threads does not collide in the underlying SparkContext/SQLExecution.
*/
private[sql] lazy val threadUuid = new ThreadLocal[String] {
override def initialValue(): String = UUID.randomUUID().toString
}
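
For readers new to `ThreadLocal`: `initialValue()` runs once per thread, on that thread's first `get()`, so every thread touching the session lazily receives its own UUID. A standalone sketch of that contract (illustrative names, not part of this diff):

```scala
import java.util.UUID

object ThreadUuidSketch extends App {
  // initialValue() runs once per thread and its result is cached per thread.
  val threadUuid = new ThreadLocal[String] {
    override def initialValue(): String = UUID.randomUUID().toString
  }

  val worker = new Thread(() => println(s"worker: ${threadUuid.get()}"))
  worker.start()
  worker.join()
  // The main thread gets its own, different UUID.
  println(s"main:   ${threadUuid.get()}")
}
```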

/**
* A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
* Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
* Real tags have the current session and thread UUIDs attached:
* `"tag1" -> s"spark-session-$sessionUUID-thread-$threadUuid-tag1"`
*/
@transient
private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
new ConcurrentHashMap(parentManagedJobTags.asJava)
}
private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
// Note: make a clone so that changes in the parent's tags aren't reflected in
// those of the child threads.
parent.clone()
}

override def initialValue(): mutable.Map[String, String] = {
mutable.Map(parentManagedJobTags.toSeq: _*)
}
}
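
The key property here is `childValue`, which runs at thread-creation time: a child thread starts from a copy of the spawning thread's map rather than a shared reference. A standalone sketch of that behavior (not part of this diff):

```scala
import scala.collection.mutable

object InheritableTagsSketch extends App {
  val tags = new InheritableThreadLocal[mutable.Map[String, String]] {
    // Called when a child thread is created; it receives a snapshot copy.
    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] =
      parent.clone()
    override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
  }

  tags.get().put("one", "real-one")
  val child = new Thread(() => {
    tags.get().put("two", "real-two") // visible only in the child
    assert(tags.get().keySet == Set("one", "two"))
  })
  child.start()
  child.join()
  assert(tags.get().keySet == Set("one")) // the child's put is not reflected back
}
```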

/** @inheritdoc */
def version: String = SPARK_VERSION
@@ -243,10 +261,10 @@
Some(sessionState),
extensions,
Map.empty,
managedJobTags.asScala.toMap)
managedJobTags.get().toMap)
result.sessionState // force copy of SessionState
result.sessionState.artifactManager // force copy of ArtifactManager and its resources
result.managedJobTags // force copy of userDefinedToRealTagsMap
result.managedJobTags // force copy of managedJobTags
result
}
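
Note that `managedJobTags.get().toMap` hands the clone an immutable snapshot of the calling thread's tags. A minimal sketch of that snapshot semantics, using a hypothetical `Session` class rather than Spark's:

```scala
import scala.collection.mutable

object CloneSnapshotSketch extends App {
  // Hypothetical Session class (not Spark's) showing snapshot-on-clone semantics.
  class Session(parentTags: Map[String, String]) {
    val tags: mutable.Map[String, String] = mutable.Map(parentTags.toSeq: _*)
    def cloneSession(): Session = new Session(tags.toMap) // immutable snapshot
  }

  val parent = new Session(Map.empty)
  parent.tags.put("one", "real-one")
  val child = parent.cloneSession()
  parent.tags.put("two", "real-two")      // added after the clone
  assert(child.tags.keySet == Set("one")) // the clone does not see "two"
}
```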

@@ -550,17 +568,17 @@
/** @inheritdoc */
override def addTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
}

/** @inheritdoc */
override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)

/** @inheritdoc */
override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
override def getTags(): Set[String] = managedJobTags.get().keySet.toSet

/** @inheritdoc */
override def clearTags(): Unit = managedJobTags.clear()
override def clearTags(): Unit = managedJobTags.get().clear()
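
Putting the pieces together, `addTag` now stores a real tag scoped to both the session and the calling thread. An illustration of the mapping with placeholder UUIDs (not real values):

```scala
object RealTagSketch extends App {
  // Placeholder UUIDs for illustration; real sessions generate these randomly.
  val sessionUUID = "0c2ef299-aaaa-bbbb-cccc-000000000000" // one per SparkSession
  val threadUuid  = "7a1bd4f0-dddd-eeee-ffff-111111111111" // one per calling thread
  val userTag     = "one"
  val realTag     = s"spark-session-$sessionUUID-thread-$threadUuid-$userTag"
  // getTags() still reports just Set("one"); realTag is what reaches the SparkContext.
  println(realTag)
}
```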

/**
* Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +607,8 @@
* @since 4.0.0
*/
override def interruptTag(tag: String): Seq[String] = {
val realTag = managedJobTags.get(tag)
if (realTag == null) return Seq.empty
doInterruptTag(realTag, s"part of cancelled job tags $tag")
val realTag = managedJobTags.get().get(tag)
realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
}
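
The rewrite also swaps the Java-style null check of `ConcurrentHashMap.get` for the `Option` returned by `scala.collection.Map.get`; a standalone equivalent of the new control flow:

```scala
object InterruptTagSketch extends App {
  // Map.get returns Option[String], so lookup and empty fallback become one expression.
  def interrupt(tags: Map[String, String], tag: String)(doInterrupt: String => Seq[String]): Seq[String] =
    tags.get(tag)            // Some(realTag) if the tag is known, None otherwise
      .map(doInterrupt)      // interrupt only when the tag exists
      .getOrElse(Seq.empty)  // unknown tag: nothing was cancelled

  val tags = Map("one" -> "spark-session-...-thread-...-one")
  assert(interrupt(tags, "one")(Seq(_)) == Seq("spark-session-...-thread-...-one"))
  assert(interrupt(tags, "missing")(Seq(_)).isEmpty)
}
```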

private def doInterruptTag(tag: String, reason: String): Seq[String] = {
sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
}

private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
sparkSession.sparkContext.addJobTags(allTags)

try {
sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite

assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
val activeJobsFuture =
session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
session.sparkContext.cancelJobsWithTagWithFuture(
session.managedJobTags.get()("one"), "reason")
val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
.split(SparkContext.SPARK_JOB_TAGS_SEP)
assert(actualTags.toSet == Set(
session.sessionJobTag,
s"${session.sessionJobTag}-one",
s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
SQLExecution.executionIdJobTag(
session,
activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
(null, null, null)
var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)

// global ExecutionContext has only 2 threads in Apache Spark CI
// create own thread pool for the three Futures used in this test
val numThreads = 3
val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
val executionContext = ExecutionContext.fromExecutorService(fpool)
val threadPool = Executors.newFixedThreadPool(3)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)

try {
// Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,35 @@
}
})

var realTagOneForSessionA: String = null

// Note: since tags are added in the Future threads, they don't need to be cleared in between.
val jobA = Future {
sessionA = globalSession.cloneSession()
import globalSession.implicits._

threadUuidA = sessionA.threadUuid.get()
assert(sessionA.getTags() == Set())
sessionA.addTag("two")
assert(sessionA.getTags() == Set("two"))
sessionA.clearTags() // check that clearing all tags works
assert(sessionA.getTags() == Set())
sessionA.addTag("one")
realTagOneForSessionA = sessionA.managedJobTags.get()("one")
assert(realTagOneForSessionA ==
s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
assert(sessionA.getTags() == Set("one"))
try {
sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
} finally {
sessionA.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}
val jobB = Future {
sessionB = globalSession.cloneSession()
import globalSession.implicits._

threadUuidB = sessionB.threadUuid.get()
assert(sessionB.getTags() == Set())
sessionB.addTag("one")
sessionB.addTag("two")
@@ -176,11 +184,12 @@
} finally {
sessionB.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}
val jobC = Future {
sessionC = globalSession.cloneSession()
import globalSession.implicits._

threadUuidC = sessionC.threadUuid.get()
sessionC.addTag("foo")
sessionC.removeTag("foo")
assert(sessionC.getTags() == Set()) // check that removing the last tag works
@@ -190,12 +199,13 @@
} finally {
sessionC.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}

// Block until all three jobs have started.
assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))

// Tags are applied
def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
assert(jobProperties.size == 3)
for (ss <- Seq(sessionA, sessionB, sessionC)) {
val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +217,17 @@
val executionRootIdTag = SQLExecution.executionIdJobTag(
ss,
jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"

ss match {
case s if s == sessionA => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
case s if s == sessionB => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
s.sessionJobTag,
executionRootIdTag,
realUserTag(s.sessionUUID, threadUuidB, "one"),
realUserTag(s.sessionUUID, threadUuidB, "two")))
case s if s == sessionC => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
}
}

@@ -239,12 +251,14 @@
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 1)

// Another job cancelled
assert(sessionA.interruptTag("one").size == 1)
// Another job cancelled. The next line cancels nothing because we're now in another thread
assert(sessionA.interruptTag("one").isEmpty)
// Have to cancel it via SparkContext using the real tag
sessionA.sparkContext.cancelJobsWithTagWithFuture(realTagOneForSessionA, "abc")
val eA = intercept[SparkException] {
ThreadUtils.awaitResult(jobA, 1.minute)
}.getCause
assert(eA.getMessage contains "cancelled job tags one")
assert(eA.getMessage contains "abc")
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 2)

@@ -257,7 +271,48 @@
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 3)
} finally {
fpool.shutdownNow()
threadPool.shutdownNow()
}
}

test("Tags are isolated in multithreaded environment") {
// Custom thread pool for multi-threaded testing
val threadPool = Executors.newFixedThreadPool(2)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)

val session = SparkSession.builder().master("local").getOrCreate()
@volatile var output1: Set[String] = null
@volatile var output2: Set[String] = null

def tag1(): Unit = {
session.addTag("tag1")
output1 = session.getTags()
}

def tag2(): Unit = {
session.addTag("tag2")
output2 = session.getTags()
}

try {
// Run tasks in separate threads
val future1 = Future {
tag1()
}
val future2 = Future {
tag2()
}

// Wait for threads to complete
ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)

// Assert outputs
assert(output1 != null)
assert(output1 == Set("tag1"))
assert(output2 != null)
assert(output2 == Set("tag2"))
} finally {
threadPool.shutdownNow()
}
}
}
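
For context on what the new isolation test guards against: before this change, a single session-level `ConcurrentHashMap` was shared by every thread using the session, so one thread's tags were visible to all others and `interruptTag` from thread B could cancel thread A's jobs. A standalone sketch of that hazard (not Spark code):

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors}

object SharedTagsHazardSketch extends App {
  // One map shared by all threads: whatever any thread puts is globally visible.
  val shared = new ConcurrentHashMap[String, String]()
  val pool = Executors.newFixedThreadPool(2)
  pool.submit(new Runnable { def run(): Unit = shared.put("tag1", "real-tag1") }).get()
  pool.submit(new Runnable { def run(): Unit = shared.put("tag2", "real-tag2") }).get()
  pool.shutdown()
  assert(shared.size() == 2) // both threads' tags visible everywhere: no isolation
}
```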