
Commit cc58fe3

juliuszsompolski authored and HyukjinKwon committed
[SPARK-44422][CONNECT] Spark Connect fine grained interrupt
### What changes were proposed in this pull request?

Currently, Spark Connect only allows cancelling all operations in a session via `SparkSession.interruptAll()`. This PR adds a mechanism to interrupt operations by tag (similar to `SparkContext.cancelJobsWithTag`), and to interrupt individual operations. It also adds the new tags to `SparkListenerConnectOperationStarted`.

### Why are the changes needed?

Better control of query cancellation in Spark Connect.

### Does this PR introduce _any_ user-facing change?

Yes. New APIs in the Spark Connect Scala client:

```
SparkSession.addTag
SparkSession.removeTag
SparkSession.getTags
SparkSession.clearTags
SparkSession.interruptTag
SparkSession.interruptOperation
```

There is also `SparkResult.operationId`, to be able to get the id for `SparkSession.interruptOperation`. Python client APIs will be added in a follow-up PR.

### How was this patch tested?

Added tests in SparkSessionE2ESuite.

Closes #42009 from juliuszsompolski/sc-fine-grained-cancel.

Authored-by: Juliusz Sompolski <julek@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit dda3784)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
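To see how the new APIs compose, here is a minimal, hypothetical usage sketch. It is not code from this patch: `spark` (assumed to be a connected Spark Connect `SparkSession`), the tag name, and the query are placeholders.

```scala
// Hypothetical usage of the new tag-based interruption APIs.
// Assumes `spark` is an active Spark Connect SparkSession.
spark.addTag("nightly-etl")

// A thread created after addTag inherits the tag, because the client keeps
// tags in an InheritableThreadLocal and clones them into child threads.
val worker = new Thread(() => {
  try spark.range(1000000000L).repartition(200).count()
  catch { case e: Exception => println(s"query interrupted: ${e.getMessage}") }
})
worker.start()

Thread.sleep(2000) // crude: let the query reach the server first

// Interrupt every operation in the session carrying the tag. The returned
// ids name the interrupted operations; one may still finish concurrently.
val interrupted: Seq[String] = spark.interruptTag("nightly-etl")

// Untag this thread so later operations are unaffected.
spark.removeTag("nightly-etl")
assert(!spark.getTags().contains("nightly-etl"))
worker.join()
```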
1 parent e8dd144 · commit cc58fe3

20 files changed (+778, -108 lines)

common/utils/src/main/resources/error/error-classes.json

Lines changed: 18 additions & 0 deletions
```diff
@@ -1383,6 +1383,24 @@
     ],
     "sqlState" : "22023"
   },
+  "INVALID_HANDLE" : {
+    "message" : [
+      "The handle <handle> is invalid."
+    ],
+    "subClass" : {
+      "ALREADY_EXISTS" : {
+        "message" : [
+          "Handle already exists."
+        ]
+      },
+      "FORMAT" : {
+        "message" : [
+          "Handle has invalid format. Handle must be a UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'"
+        ]
+      }
+    },
+    "sqlState" : "HY000"
+  },
   "INVALID_HIVE_COLUMN_NAME" : {
     "message" : [
       "Cannot create the table <tableName> having the nested column <columnName> whose name contains invalid characters <invalidChars> in Hive metastore."
```

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 75 additions & 5 deletions
```diff
@@ -613,14 +613,40 @@ class SparkSession private[sql] (
   /**
    * Interrupt all operations of this session currently running on the connected server.
    *
-   * TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from
-   * ExecutePlan requests. If an operation is not running a Spark Job, it becomes a no-op and the
-   * operation will continue afterwards, possibly with more Spark Jobs.
+   * @return
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
+   *   an operation finishing just as it is interrupted.
    *
    * @since 3.5.0
    */
-  def interruptAll(): Unit = {
-    client.interruptAll()
+  def interruptAll(): Seq[String] = {
+    client.interruptAll().getInterruptedIdsList.asScala.toSeq
+  }
+
+  /**
+   * Interrupt all operations of this session with the given operation tag.
+   *
+   * @return
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
+   *   an operation finishing just as it is interrupted.
+   *
+   * @since 3.5.0
+   */
+  def interruptTag(tag: String): Seq[String] = {
+    client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
+  }
+
+  /**
+   * Interrupt an operation of this session with the given operationId.
+   *
+   * @return
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
+   *   an operation finishing just as it is interrupted.
+   *
+   * @since 3.5.0
+   */
+  def interruptOperation(operationId: String): Seq[String] = {
+    client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
   }
 
   /**
@@ -641,6 +667,50 @@
     allocator.close()
     SparkSession.onSessionClose(this)
   }
+
+  /**
+   * Add a tag to be assigned to all the operations started by this thread in this session.
+   *
+   * @param tag
+   *   The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+   *
+   * @since 3.5.0
+   */
+  def addTag(tag: String): Unit = {
+    client.addTag(tag)
+  }
+
+  /**
+   * Remove a tag previously added to be assigned to all the operations started by this thread in
+   * this session. No-op if such a tag was not added earlier.
+   *
+   * @param tag
+   *   The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+   *
+   * @since 3.5.0
+   */
+  def removeTag(tag: String): Unit = {
+    client.removeTag(tag)
+  }
+
+  /**
+   * Get the tags that are currently set to be assigned to all the operations started by this
+   * thread.
+   *
+   * @since 3.5.0
+   */
+  def getTags(): Set[String] = {
+    client.getTags()
+  }
+
+  /**
+   * Clear the current thread's operation tags.
+   *
+   * @since 3.5.0
+   */
+  def clearTags(): Unit = {
+    client.clearTags()
+  }
 }
 
 // The minimal builder needed to create a spark session.
```
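Note the signature change: `interruptAll()` now returns `Seq[String]` instead of `Unit`. A short hedged sketch of what a caller can do with that, where `spark` again stands in for a connected session:

```scala
// The ids identify which operations the server attempted to interrupt.
val ids: Seq[String] = spark.interruptAll()
ids.foreach(id => println(s"interrupt requested for operation $id"))
// Per the scaladoc above, this is best-effort: an operation may have
// finished just as the interrupt arrived.
```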

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 58 additions & 0 deletions
```diff
@@ -21,11 +21,15 @@ import java.net.URI
 import java.util.UUID
 import java.util.concurrent.Executor
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
 import com.google.protobuf.ByteString
 import io.grpc._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.UserContext
+import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 
 /**
@@ -76,6 +80,7 @@ private[sql] class SparkConnectClient(
       .setUserContext(userContext)
       .setSessionId(sessionId)
       .setClientType(userAgent)
+      .addAllTags(tags.get.toSeq.asJava)
       .build()
     bstub.executePlan(request)
   }
@@ -195,6 +200,59 @@
     bstub.interrupt(request)
   }
 
+  private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
+    val builder = proto.InterruptRequest.newBuilder()
+    val request = builder
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setClientType(userAgent)
+      .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
+      .setOperationTag(tag)
+      .build()
+    bstub.interrupt(request)
+  }
+
+  private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
+    val builder = proto.InterruptRequest.newBuilder()
+    val request = builder
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setClientType(userAgent)
+      .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
+      .setOperationId(id)
+      .build()
+    bstub.interrupt(request)
+  }
+
+  private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
+    override def childValue(parent: mutable.Set[String]): mutable.Set[String] = {
+      // Note: make a clone such that changes in the parent tags aren't reflected in
+      // those of the child threads.
+      parent.clone()
+    }
+    override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]()
+  }
+
+  private[sql] def addTag(tag: String): Unit = {
+    // Validation is also done server side, but this will give the error earlier.
+    ProtoUtils.throwIfInvalidTag(tag)
+    tags.get += tag
+  }
+
+  private[sql] def removeTag(tag: String): Unit = {
+    // Validation is also done server side, but this will give the error earlier.
+    ProtoUtils.throwIfInvalidTag(tag)
+    tags.get.remove(tag)
+  }
+
+  private[sql] def getTags(): Set[String] = {
+    tags.get.toSet
+  }
+
+  private[sql] def clearTags(): Unit = {
+    tags.get.clear()
+  }
+
   def copy(): SparkConnectClient = configuration.toSparkConnectClient
 
   /**
```
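The `childValue` override above is what keeps tag sets thread-scoped. Below is a self-contained demonstration of that semantics, not client code; names are hypothetical and plain Scala is used in place of the client internals:

```scala
import scala.collection.mutable

// Demonstrates why childValue clones the set: without the clone, parent and
// child threads would share a single mutable.Set.
object TagInheritanceDemo extends App {
  val tags = new InheritableThreadLocal[mutable.Set[String]] {
    override def childValue(parent: mutable.Set[String]): mutable.Set[String] =
      parent.clone()
    override protected def initialValue(): mutable.Set[String] =
      new mutable.HashSet[String]()
  }

  tags.get += "parent-tag"
  // InheritableThreadLocal copies (here: clones) the value when the child
  // thread is constructed, so the child starts with {"parent-tag"}.
  val child = new Thread(() => {
    tags.get += "child-tag"
    assert(tags.get == Set("parent-tag", "child-tag"))
  })
  child.start()
  child.join()
  // The child's mutation did not leak back into the parent's set.
  assert(tags.get == Set("parent-tag"))
}
```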

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 27 additions & 0 deletions
```diff
@@ -40,6 +40,7 @@ private[sql] class SparkResult[T](
     extends AutoCloseable
     with Cleanable { self =>
 
+  private[this] var opId: String = _
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
   private[this] var arrowSchema: pojo.Schema = _
@@ -72,13 +73,28 @@
   }
 
   private def processResponses(
+      stopOnOperationId: Boolean = false,
       stopOnSchema: Boolean = false,
       stopOnArrowSchema: Boolean = false,
      stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
     var nonEmpty = false
     var stop = false
     while (!stop && responses.hasNext) {
       val response = responses.next()
+
+      // Save and validate operationId
+      if (opId == null) {
+        opId = response.getOperationId
+      }
+      if (opId != response.getOperationId) {
+        // Backwards compatibility: a response from an old server without the
+        // operationId field would have getOperationId == "".
+        throw new IllegalStateException(
+          "Received response with wrong operationId. " +
+            s"Expected '$opId' but received '${response.getOperationId}'.")
+      }
+      stop |= stopOnOperationId
+
       if (response.hasSchema) {
         // The original schema should arrive before ArrowBatches.
         structType =
@@ -148,6 +164,17 @@
     structType
   }
 
+  /**
+   * @return
+   *   the operationId of the result.
+   */
+  def operationId: String = {
+    if (opId == null) {
+      processResponses(stopOnOperationId = true)
+    }
+    opId
+  }
+
   /**
    * Create an Array with the contents of the result.
    */
```
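With `operationId` exposed, one specific operation can be targeted. A hedged sketch follows; it assumes `spark` is a connected session and that the client `Dataset` exposes a `collectResult(): SparkResult[_]`-style method:

```scala
// Obtain a result handle and its operation id; reading operationId may
// consume responses until the id is known (stopOnOperationId = true above).
val result = spark.sql("SELECT * FROM very_large_table").collectResult()
val opId: String = result.operationId

// From any thread of the same session, cancel just that operation.
val interrupted: Seq[String] = spark.interruptOperation(opId)
// `interrupted` is empty if the operation had already completed.
```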
