
Commit 31332e2

juliuszsompolski and HyukjinKwon authored and committed
[SPARK-43755][CONNECT] Move execution out of SparkExecutePlanStreamHandler and to a different thread
### What changes were proposed in this pull request?

* Move code related to execution from being done directly in the GRPC callback in SparkConnectStreamHandler into its own classes.
* `ExecutionHolder` (renamed from `ExecuteHolder`) launches the execution in its own thread using `ExecuteThreadRunner`.
* The execution pushes responses via `ExecuteResponseObserver` (running in the execution thread).
* `ExecuteResponseObserver` notifies `ExecuteGrpcResponseSender` (running in the RPC handler thread) to send the responses.
* The actual execution code is refactored into `SparkConnectPlanExecution`.

This improves query interruption by making the `interrupt` method interrupt the execution thread, so `interrupt` also works when no Spark jobs are running. A minimal sketch of the thread handoff follows this description.

The refactoring further opens up the possibility of detaching query execution from a single RPC execution. Right now, `ExecutionHolder` waits for the execution thread to finish, and `ExecuteResponseObserver` forwards the responses directly to the RPC observer. In a followup, we can design different modes of execution, e.g.:

* `ExecuteResponseObserver` buffering the responses. A client which lost its connection could then reconnect and ask for the stream to be retransmitted.
* `ExecutionHolder` returning the operationId to the client directly, with the client then requesting results in separate RPCs, with more control over the response stream, instead of having it just pushed to it.

### Why are the changes needed?

* Improve the behavior of `interrupt`.
* Refactoring that opens up the possibility of detaching query execution from a single RPC.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing Spark Connect CI covers the execution.

Closes apache#41315 from juliuszsompolski/sc-execute-thread.

Lead-authored-by: Juliusz Sompolski <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Yuanjian Li <[email protected]>
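As a minimal sketch of that thread handoff (all names here are hypothetical simplifications for illustration, not the actual `ExecuteThreadRunner` API):

```scala
// Hypothetical sketch: run the execution on a dedicated thread so that
// interrupt() works even when no Spark jobs are running.
class ExecuteThreadRunnerSketch(runExecution: () => Unit) {
  private val thread = new Thread(
    () =>
      try {
        runExecution() // pushes responses into an ExecuteResponseObserver
      } catch {
        // surfaced to the client as the OPERATION_CANCELED error class
        case _: InterruptedException => ()
      },
    "spark-connect-execute-sketch")

  def start(): Unit = thread.start()

  // Interrupts the execution thread itself, not just running Spark jobs.
  def interrupt(): Unit = thread.interrupt()

  def join(): Unit = thread.join()
}
```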
1 parent 9698d1e commit 31332e2

17 files changed: +905 −300 lines

common/utils/src/main/resources/error/error-classes.json

Lines changed: 6 additions & 0 deletions
@@ -2166,6 +2166,12 @@
       "Number of given aliases does not match number of output columns. Function name: <funcName>; number of aliases: <aliasesNum>; number of output columns: <outColsNum>."
     ]
   },
+  "OPERATION_CANCELED" : {
+    "message" : [
+      "Operation has been canceled."
+    ],
+    "sqlState" : "HY008"
+  },
   "ORDER_BY_POS_OUT_OF_RANGE" : {
     "message" : [
       "ORDER BY position <index> is not in select list (valid range is [1, <size>])."

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala

Lines changed: 4 additions & 4 deletions
@@ -49,15 +49,15 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
     q1.onComplete {
       case Success(_) =>
         error = Some("q1 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
+      case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") =>
         q1Interrupted = true
       case Failure(t) =>
         error = Some("unexpected failure in q1: " + t.toString)
     }
     q2.onComplete {
       case Success(_) =>
         error = Some("q2 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
+      case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") =>
         q2Interrupted = true
       case Failure(t) =>
         error = Some("unexpected failure in q2: " + t.toString)
@@ -89,11 +89,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
     val e1 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
-    assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1")
+    assert(e1.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e1")
     val e2 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
-    assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2")
+    assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2")
     finished = true
     assert(ThreadUtils.awaitResult(interruptor, 10.seconds))
   }
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+private[execution] case class CachedStreamResponse[T](
+    // the actual cached response
+    response: T,
+    // index of the response in the response stream.
+    // responses produced in the stream are numbered consecutively starting from 1.
+    streamIndex: Long)
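A trivial illustration of the indexing contract (hypothetical values; the class is `private[execution]`, so this only compiles inside that package):

```scala
// Hypothetical: responses are cached under consecutive 1-based stream indexes.
val first = CachedStreamResponse(response = "r1", streamIndex = 1L)
val second = CachedStreamResponse(response = "r2", streamIndex = 2L)
assert(second.streamIndex == first.streamIndex + 1)
```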
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.internal.Logging
+
+/**
+ * ExecuteGrpcResponseSender sends responses to the GRPC stream. It runs on the RPC thread, and
+ * gets notified by ExecuteResponseObserver about available responses. It notifies the
+ * ExecuteResponseObserver back about cached responses that can be removed after being sent out.
+ * @param grpcObserver
+ *   the GRPC request StreamObserver
+ */
+private[connect] class ExecuteGrpcResponseSender[T](grpcObserver: StreamObserver[T])
+    extends Logging {
+
+  private var detached = false
+
+  /**
+   * Detach this sender from executionObserver. Called only from the executionObserver that this
+   * sender is attached to. The executionObserver holds its lock during this call, and needs to
+   * notify after it.
+   */
+  def detach(): Unit = {
+    if (detached) {
+      throw new IllegalStateException("ExecuteGrpcResponseSender already detached!")
+    }
+    detached = true
+  }
+
+  /**
+   * Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
+   * @param lastConsumedStreamIndex
+   *   the last index that was already consumed and sent. This sender will start from the index
+   *   after that. 0 means start from the beginning (since the first response has index 1).
+   *
+   * @return
+   *   true if the execution was detached before the stream completed; the caller then needs to
+   *   finish the grpcObserver stream. false if the stream finished; in that case the
+   *   grpcObserver stream is already completed.
+   */
+  def run(
+      executionObserver: ExecuteResponseObserver[T],
+      lastConsumedStreamIndex: Long): Boolean = {
+    // register to be notified about available responses.
+    executionObserver.attachConsumer(this)
+
+    var nextIndex = lastConsumedStreamIndex + 1
+    var finished = false
+
+    while (!finished) {
+      var response: Option[CachedStreamResponse[T]] = None
+      // Get the next available response.
+      // Wait until either this sender got detached, the next response is ready,
+      // or the stream is complete and all responses have already been sent.
+      logDebug(s"Trying to get next response with index=$nextIndex.")
+      executionObserver.synchronized {
+        logDebug(s"Acquired lock.")
+        while (!detached && response.isEmpty &&
+            executionObserver.getLastIndex().forall(nextIndex <= _)) {
+          logDebug(s"Try to get response with index=$nextIndex from observer.")
+          response = executionObserver.getResponse(nextIndex)
+          logDebug(s"Response index=$nextIndex from observer: ${response.isDefined}")
+          // If response is empty, release the executionObserver lock and wait to get notified.
+          // detached, response and lastIndex are changed under the executionObserver lock,
+          // which will notify upon state change.
+          if (response.isEmpty) {
+            logDebug(s"Wait for response to become available.")
+            executionObserver.wait()
+            logDebug(s"Reacquired lock after waiting.")
+          }
+        }
+        logDebug(
+          s"Exiting loop: detached=$detached, response=$response, " +
+            s"lastIndex=${executionObserver.getLastIndex()}")
+      }
+
+      // Send the next available response.
+      if (detached) {
+        // This sender got detached by the observer.
+        logDebug(s"Detached from observer at index ${nextIndex - 1}. Complete stream.")
+        finished = true
+      } else if (response.isDefined) {
+        // There is a response available to be sent.
+        grpcObserver.onNext(response.get.response)
+        logDebug(s"Sent response index=$nextIndex.")
+        nextIndex += 1
+      } else if (executionObserver.getLastIndex().forall(nextIndex > _)) {
+        // Stream is finished and all responses have been sent.
+        logDebug(s"Sent all responses up to index ${nextIndex - 1}.")
+        executionObserver.getError() match {
+          case Some(t) => grpcObserver.onError(t)
+          case None => grpcObserver.onCompleted()
+        }
+        finished = true
+      }
+    }
+    // Return true if this sender was detached, false if the stream finished.
+    detached
+  }
+}
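A hedged sketch of how this sender is wired to an observer. The actual wiring lives in the execute handler, which is not part of this excerpt; `ResponseStreamingSketch` and `streamResponses` are illustrative names:

```scala
import io.grpc.stub.StreamObserver

object ResponseStreamingSketch {
  // Sketch: the RPC thread consumes what the execution thread produces.
  def streamResponses[T](
      grpcObserver: StreamObserver[T],
      executionObserver: ExecuteResponseObserver[T]): Unit = {
    val sender = new ExecuteGrpcResponseSender[T](grpcObserver)
    // Blocks on executionObserver's monitor until the stream finishes
    // (run returns false) or this sender is detached (run returns true).
    val wasDetached = sender.run(executionObserver, lastConsumedStreamIndex = 0)
    if (wasDetached) {
      // Another sender took over; this RPC must complete its own stream.
      grpcObserver.onCompleted()
    }
  }
}
```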
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import scala.collection.mutable
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.internal.Logging
+
+/**
+ * This StreamObserver runs on the execution thread. The execution pushes responses to it, and it
+ * caches them. ExecuteGrpcResponseSender is the consumer of the responses that
+ * ExecuteResponseObserver "produces". It waits on the monitor of ExecuteResponseObserver, and
+ * newly produced responses notify the monitor.
+ * @see
+ *   getResponse.
+ *
+ * ExecuteResponseObserver controls how long responses stay cached after being returned to the
+ * consumer,
+ * @see
+ *   removeCachedResponses.
+ *
+ * A single ExecuteGrpcResponseSender can be attached to the ExecuteResponseObserver. Attaching a
+ * new one will notify an existing one that it was detached.
+ * @see
+ *   attachConsumer
+ */
+private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] with Logging {
+
+  /**
+   * Cached responses produced by the execution. Map from response index -> response. Response
+   * indexes are numbered consecutively starting from 1.
+   */
+  private val responses: mutable.Map[Long, CachedStreamResponse[T]] =
+    new mutable.HashMap[Long, CachedStreamResponse[T]]()
+
+  /** Cached error of the execution, if an error was thrown. */
+  private var error: Option[Throwable] = None
+
+  /**
+   * If the execution stream is finished (completed or with error), the index of the final
+   * response.
+   */
+  private var finalProducedIndex: Option[Long] = None // index of final response before completed.
+
+  /** The index of the last response produced by the execution. */
+  private var lastProducedIndex: Long = 0 // first response will have index 1
+
+  /**
+   * Highest response index that was consumed. Kept to decide which responses need to be cached,
+   * and to assert that all responses are consumed.
+   */
+  private var highestConsumedIndex: Long = 0
+
+  /**
+   * Consumer that waits for available responses. There can be only one at a time, @see
+   * attachConsumer.
+   */
+  private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None
+
+  def onNext(r: T): Unit = synchronized {
+    if (finalProducedIndex.nonEmpty) {
+      throw new IllegalStateException("Stream onNext can't be called after stream completed")
+    }
+    lastProducedIndex += 1
+    responses += ((lastProducedIndex, CachedStreamResponse[T](r, lastProducedIndex)))
+    logDebug(s"Saved response with index=$lastProducedIndex")
+    notifyAll()
+  }
+
+  def onError(t: Throwable): Unit = synchronized {
+    if (finalProducedIndex.nonEmpty) {
+      throw new IllegalStateException("Stream onError can't be called after stream completed")
+    }
+    error = Some(t)
+    finalProducedIndex = Some(lastProducedIndex) // no responses to be sent after error.
+    logDebug(s"Error. Last stream index is $lastProducedIndex.")
+    notifyAll()
+  }
+
+  def onCompleted(): Unit = synchronized {
+    if (finalProducedIndex.nonEmpty) {
+      throw new IllegalStateException("Stream onCompleted can't be called after stream completed")
+    }
+    finalProducedIndex = Some(lastProducedIndex)
+    logDebug(s"Completed. Last stream index is $lastProducedIndex.")
+    notifyAll()
+  }
+
+  /** Attach a new consumer (ExecuteGrpcResponseSender). */
+  def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = synchronized {
+    // detach the current sender before attaching the new one.
+    // this.synchronized needs to be held while detaching a sender, and the detached sender
+    // needs to be notified with notifyAll() afterwards.
+    responseSender.foreach(_.detach())
+    responseSender = Some(newSender)
+    notifyAll() // notify the new consumer
+  }
+
+  /** Get the response with a given index in the stream, if set. */
+  def getResponse(index: Long): Option[CachedStreamResponse[T]] = synchronized {
+    // we index stream responses from 1, so getting a lower index would be invalid.
+    assert(index >= 1)
+    // it would be invalid if the consumer skipped a response.
+    assert(index <= highestConsumedIndex + 1)
+    val ret = responses.get(index)
+    if (ret.isDefined) {
+      if (index > highestConsumedIndex) highestConsumedIndex = index
+      removeCachedResponses()
+    }
+    ret
+  }
+
+  /** Get the stream error if there is one, otherwise None. */
+  def getError(): Option[Throwable] = synchronized {
+    error
+  }
+
+  /** If the stream is finished, the index of the last response, otherwise None. */
+  def getLastIndex(): Option[Long] = synchronized {
+    finalProducedIndex
+  }
+
+  /** Returns whether the stream is finished. */
+  def completed(): Boolean = synchronized {
+    finalProducedIndex.isDefined
+  }
+
+  /** The consumer (ExecuteGrpcResponseSender) waits on the monitor of ExecuteResponseObserver. */
+  private def notifyConsumer(): Unit = {
+    notifyAll()
+  }
+
+  /**
+   * Remove cached responses after the response with lastReturnedIndex is returned from
+   * getResponse, according to the caching policy:
+   *   - if the query is not reattachable, remove all responses up to and including
+   *     highestConsumedIndex.
+   */
+  private def removeCachedResponses(): Unit = {
+    var i = highestConsumedIndex
+    while (i >= 1 && responses.get(i).isDefined) {
+      responses.remove(i)
+      i -= 1
+    }
+  }
+}
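A single-threaded walkthrough of the observer's produce/consume contract (hypothetical driver code; in the PR the producer and consumer run on separate threads, and the class is `private[connect]`):

```scala
val observer = new ExecuteResponseObserver[String]()
observer.onNext("a") // cached at index 1
observer.onNext("b") // cached at index 2
observer.onCompleted() // getLastIndex() becomes Some(2)

assert(observer.getResponse(1).map(_.response) == Some("a"))
// Consuming index 1 evicts it from the cache (the query is not reattachable).
assert(observer.getResponse(2).map(_.response) == Some("b"))
assert(observer.getLastIndex() == Some(2L))
assert(observer.completed() && observer.getError().isEmpty)
```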
