Skip to content

Commit 459483a

Browse files
committed
[SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark
### What changes were proposed in this pull request?

This PR proposes to support `Interrupt(Tag|All)` for PySpark.

### Why are the changes needed?

To improve the compatibility between Spark Connect and Spark Classic.

### Does this PR introduce _any_ user-facing change?

New APIs are added:
- InterruptTag
- InterruptAll

### How was this patch tested?

Added UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49014 from itholic/SPARK-50357.

Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Haejoon Lee <[email protected]>
1 parent 56284bf commit 459483a

File tree

6 files changed

+48
-37
lines changed

6 files changed

+48
-37
lines changed

python/docs/source/reference/pyspark.sql/spark_session.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ See also :class:`SparkSession`.
5252
SparkSession.dataSource
5353
SparkSession.getActiveSession
5454
SparkSession.getTags
55+
SparkSession.interruptAll
56+
SparkSession.interruptTag
5557
SparkSession.newSession
5658
SparkSession.profile
5759
SparkSession.removeTag
@@ -86,8 +88,6 @@ Spark Connect Only
8688
SparkSession.clearProgressHandlers
8789
SparkSession.client
8890
SparkSession.copyFromLocalToFs
89-
SparkSession.interruptAll
9091
SparkSession.interruptOperation
91-
SparkSession.interruptTag
9292
SparkSession.registerProgressHandler
9393
SparkSession.removeProgressHandler

python/pyspark/sql/session.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,13 +2197,15 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
21972197
messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
21982198
)
21992199

2200-
@remote_only
22012200
def interruptAll(self) -> List[str]:
22022201
"""
22032202
Interrupt all operations of this session currently running on the connected server.
22042203
22052204
.. versionadded:: 3.5.0
22062205
2206+
.. versionchanged:: 4.0.0
2207+
Supports Spark Classic.
2208+
22072209
Returns
22082210
-------
22092211
list of str
@@ -2213,18 +2215,25 @@ def interruptAll(self) -> List[str]:
22132215
-----
22142216
There is still a possibility of operation finishing just as it is interrupted.
22152217
"""
2216-
raise PySparkRuntimeError(
2217-
errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
2218-
messageParameters={"feature": "SparkSession.interruptAll"},
2219-
)
2218+
java_list = self._jsparkSession.interruptAll()
2219+
python_list = list()
2220+
2221+
# Use iterator to manually iterate through Java list
2222+
java_iterator = java_list.iterator()
2223+
while java_iterator.hasNext():
2224+
python_list.append(str(java_iterator.next()))
2225+
2226+
return python_list
22202227

2221-
@remote_only
22222228
def interruptTag(self, tag: str) -> List[str]:
22232229
"""
22242230
Interrupt all operations of this session with the given operation tag.
22252231
22262232
.. versionadded:: 3.5.0
22272233
2234+
.. versionchanged:: 4.0.0
2235+
Supports Spark Classic.
2236+
22282237
Returns
22292238
-------
22302239
list of str
@@ -2234,10 +2243,15 @@ def interruptTag(self, tag: str) -> List[str]:
22342243
-----
22352244
There is still a possibility of operation finishing just as it is interrupted.
22362245
"""
2237-
raise PySparkRuntimeError(
2238-
errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
2239-
messageParameters={"feature": "SparkSession.interruptTag"},
2240-
)
2246+
java_list = self._jsparkSession.interruptTag(tag)
2247+
python_list = list()
2248+
2249+
# Use iterator to manually iterate through Java list
2250+
java_iterator = java_list.iterator()
2251+
while java_iterator.hasNext():
2252+
python_list.append(str(java_iterator.next()))
2253+
2254+
return python_list
22412255

22422256
@remote_only
22432257
def interruptOperation(self, op_id: str) -> List[str]:

python/pyspark/sql/tests/connect/test_parity_job_cancellation.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,6 @@ def func(target):
3232
create_thread=lambda target, session: threading.Thread(target=func, args=(target,))
3333
)
3434

35-
def test_interrupt_tag(self):
36-
thread_ids = range(4)
37-
self.check_job_cancellation(
38-
lambda job_group: self.spark.addTag(job_group),
39-
lambda job_group: self.spark.interruptTag(job_group),
40-
thread_ids,
41-
[i for i in thread_ids if i % 2 == 0],
42-
[i for i in thread_ids if i % 2 != 0],
43-
)
44-
self.spark.clearTags()
45-
46-
def test_interrupt_all(self):
47-
thread_ids = range(4)
48-
self.check_job_cancellation(
49-
lambda job_group: None,
50-
lambda job_group: self.spark.interruptAll(),
51-
thread_ids,
52-
thread_ids,
53-
[],
54-
)
55-
self.spark.clearTags()
56-
5735

5836
if __name__ == "__main__":
5937
import unittest

python/pyspark/sql/tests/test_connect_compatibility.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,7 @@ def test_spark_session_compatibility(self):
266266
"addArtifacts",
267267
"clearProgressHandlers",
268268
"copyFromLocalToFs",
269-
"interruptAll",
270269
"interruptOperation",
271-
"interruptTag",
272270
"newSession",
273271
"registerProgressHandler",
274272
"removeProgressHandler",

python/pyspark/sql/tests/test_job_cancellation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,28 @@ def get_outer_local_prop():
166166
self.assertEqual(first, {"a", "b"})
167167
self.assertEqual(second, {"a", "b", "c"})
168168

169+
def test_interrupt_tag(self):
170+
thread_ids = range(4)
171+
self.check_job_cancellation(
172+
lambda job_group: self.spark.addTag(job_group),
173+
lambda job_group: self.spark.interruptTag(job_group),
174+
thread_ids,
175+
[i for i in thread_ids if i % 2 == 0],
176+
[i for i in thread_ids if i % 2 != 0],
177+
)
178+
self.spark.clearTags()
179+
180+
def test_interrupt_all(self):
181+
thread_ids = range(4)
182+
self.check_job_cancellation(
183+
lambda job_group: None,
184+
lambda job_group: self.spark.interruptAll(),
185+
thread_ids,
186+
thread_ids,
187+
[],
188+
)
189+
self.spark.clearTags()
190+
169191

170192
class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
171193
pass

python/pyspark/sql/tests/test_session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def test_unsupported_api(self):
227227
(lambda: session.client, "client"),
228228
(session.addArtifacts, "addArtifact(s)"),
229229
(lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
230-
(lambda: session.interruptTag(""), "interruptTag"),
231230
(lambda: session.interruptOperation(""), "interruptOperation"),
232231
]
233232

0 commit comments

Comments (0)