Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
interruptTag test
  • Loading branch information
juliuszsompolski committed Jul 20, 2023
commit 32126e7772bd87efbabb51584bf4abdcada16e96
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/
package org.apache.spark.sql

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration._
import scala.mutable
import scala.util.{Failure, Success}

import org.scalatest.concurrent.Eventually._
Expand Down Expand Up @@ -74,7 +74,7 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
assert(q1Interrupted)
assert(q2Interrupted)
}
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
assert(interrupted.distinct.length == 2, s"Interrupted operations: ${interrupted.distinct}.")
}

test("interrupt all - foreground queries, background interrupt") {
Expand Down Expand Up @@ -103,6 +103,114 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2")
finished = true
assert(ThreadUtils.awaitResult(interruptor, 10.seconds))
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
assert(interrupted.distinct.length == 2, s"Interrupted operations: ${interrupted.distinct}.")
}

test("interrupt tag") {
  val session = spark
  import session.implicits._

  // global ExecutionContext has only 2 threads in Apache Spark CI
  // create own thread pool for four Futures used in this test
  val numThreads = 4
  val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
  val executionContext = ExecutionContext.fromExecutorService(fpool)

  // Awaits `query` and checks that it failed with a SparkException whose cause reports
  // OPERATION_CANCELED. The cause is included in the failure message to ease debugging.
  def assertCanceled[T](query: Future[T]): Unit = {
    val cause = intercept[SparkException] {
      ThreadUtils.awaitResult(query, 1.minute)
    }.getCause
    assert(cause.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $cause")
  }

  try {
    // q1 ends up tagged with "one" only.
    val q1 = Future {
      assert(spark.getTags() == Set())
      spark.addTag("two")
      assert(spark.getTags() == Set("two"))
      spark.clearTags() // check that clearing all tags works
      assert(spark.getTags() == Set())
      spark.addTag("one")
      assert(spark.getTags() == Set("one"))
      try {
        spark.range(10).map(n => {
          Thread.sleep(30000); n
        }).collect()
      } finally {
        spark.clearTags() // clear for the case of thread reuse by another Future
      }
    }(executionContext)
    // q2 ends up tagged with both "one" and "two".
    val q2 = Future {
      assert(spark.getTags() == Set())
      spark.addTag("one")
      spark.addTag("two")
      spark.addTag("one")
      spark.addTag("two") // duplicates shouldn't matter
      try {
        spark.range(10).map(n => {
          Thread.sleep(30000); n
        }).collect()
      } finally {
        spark.clearTags() // clear for the case of thread reuse by another Future
      }
    }(executionContext)
    // q3 ends up tagged with "two" only.
    val q3 = Future {
      assert(spark.getTags() == Set())
      spark.addTag("foo")
      spark.removeTag("foo")
      assert(spark.getTags() == Set()) // check that remove works removing the last tag
      spark.addTag("two")
      assert(spark.getTags() == Set("two"))
      try {
        spark.range(10).map(n => {
          Thread.sleep(30000); n
        }).collect()
      } finally {
        spark.clearTags() // clear for the case of thread reuse by another Future
      }
    }(executionContext)
    // q4 ends up tagged with "one" only.
    val q4 = Future {
      assert(spark.getTags() == Set())
      spark.addTag("one")
      spark.addTag("two")
      spark.addTag("two")
      assert(spark.getTags() == Set("one", "two"))
      spark.removeTag("two") // check that remove works, despite duplicate add
      assert(spark.getTags() == Set("one"))
      try {
        spark.range(10).map(n => {
          Thread.sleep(30000); n
        }).collect()
      } finally {
        spark.clearTags() // clear for the case of thread reuse by another Future
      }
    }(executionContext)
    val interrupted = mutable.ListBuffer[String]()

    // q2 and q3 should be cancelled
    interrupted.clear()
    // Retry the interrupt: the queries may not have started yet, and ids accumulate across
    // attempts, hence the `distinct` when counting.
    eventually(timeout(20.seconds), interval(1.seconds)) {
      val ids = spark.interruptTag("two")
      interrupted ++= ids
      assert(
        interrupted.distinct.length == 2,
        s"Interrupted operations: ${interrupted.distinct}.")
    }
    assertCanceled(q2)
    assertCanceled(q3)
    assert(
      interrupted.distinct.length == 2,
      s"Interrupted operations: ${interrupted.distinct}.")

    // q1 and q4 should be cancelled
    interrupted.clear()
    eventually(timeout(20.seconds), interval(1.seconds)) {
      val ids = spark.interruptTag("one")
      interrupted ++= ids
      assert(
        interrupted.distinct.length == 2,
        s"Interrupted operations: ${interrupted.distinct}.")
    }
    assertCanceled(q1)
    assertCanceled(q4)
    assert(
      interrupted.distinct.length == 2,
      s"Interrupted operations: ${interrupted.distinct}.")
  } finally {
    // Release the dedicated pool so its worker threads do not outlive this test.
    // shutdown() (rather than shutdownNow()) lets any still-running query finish on its own.
    fpool.shutdown()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
} finally {
executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag)
executeHolder.userDefinedTags.foreach { tag =>
session.sparkContext.removeJobTag(executeHolder.tagToSparkJobTag(tag))
executeHolder.sessionHolder.session.sparkContext.removeJobTag(
executeHolder.tagToSparkJobTag(tag))
}
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private[connect] class ExecuteHolder(
* need to be combined with userId and sessionId.
*/
def tagToSparkJobTag(tag: String): String = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@juliuszsompolski input tag isn't used for output, which doesn't look intended

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for spotting!
@HyukjinKwon could you maybe piggyback changing it to

    "SparkConnect_Execute_" +
      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"

to #42120 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

"SparkConnectUserDefinedTag_" +
"SparkConnect_Tag_" +
s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
}
}