[SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark
itholic committed Jan 3, 2025
commit 6a9f8aaff31e41ef398a11a600b6eb06e9bae9d6
34 changes: 24 additions & 10 deletions python/pyspark/sql/session.py
@@ -2197,13 +2197,15 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
             messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
         )
 
-    @remote_only
     def interruptAll(self) -> List[str]:
         """
         Interrupt all operations of this session currently running on the connected server.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2213,18 +2215,25 @@ def interruptAll(self) -> List[str]:
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptAll"},
-        )
+        java_list = self._jsparkSession.interruptAll()
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
-    @remote_only
     def interruptTag(self, tag: str) -> List[str]:
         """
         Interrupt all operations of this session with the given operation tag.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2234,10 +2243,15 @@ def interruptTag(self, tag: str) -> List[str]:
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptTag"},
-        )
+        java_list = self._jsparkSession.interruptTag(tag)
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
     @remote_only
     def interruptOperation(self, op_id: str) -> List[str]:
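With the Connect-only guards removed above, `interruptAll` and `interruptTag` now delegate to the JVM session on Spark Classic and return the interrupted operation IDs as Python strings. A minimal usage sketch of the resulting API (the session setup and tag name are illustrative, not part of the patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Tag operations issued by this thread, then cancel them by tag.
spark.addTag("cancellable-batch")
try:
    # ... long-running jobs started under this tag, typically from other threads ...
    interrupted = spark.interruptTag("cancellable-batch")  # list of interrupted operation IDs
finally:
    spark.clearTags()

# Or cancel everything currently running on this session.
interrupted_all = spark.interruptAll()
```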
22 changes: 0 additions & 22 deletions python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
@@ -32,28 +32,6 @@ def func(target):
         create_thread=lambda target, session: threading.Thread(target=func, args=(target,))
     )
 
-    def test_interrupt_tag(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: self.spark.addTag(job_group),
-            lambda job_group: self.spark.interruptTag(job_group),
-            thread_ids,
-            [i for i in thread_ids if i % 2 == 0],
-            [i for i in thread_ids if i % 2 != 0],
-        )
-        self.spark.clearTags()
-
-    def test_interrupt_all(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: None,
-            lambda job_group: self.spark.interruptAll(),
-            thread_ids,
-            thread_ids,
-            [],
-        )
-        self.spark.clearTags()
-
 
 if __name__ == "__main__":
     import unittest
2 changes: 0 additions & 2 deletions python/pyspark/sql/tests/test_connect_compatibility.py
@@ -266,9 +266,7 @@ def test_spark_session_compatibility(self):
             "addArtifacts",
             "clearProgressHandlers",
             "copyFromLocalToFs",
-            "interruptAll",
             "interruptOperation",
-            "interruptTag",
             "newSession",
             "registerProgressHandler",
             "removeProgressHandler",
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests/test_job_cancellation.py
@@ -166,6 +166,28 @@ def get_outer_local_prop():
         self.assertEqual(first, {"a", "b"})
         self.assertEqual(second, {"a", "b", "c"})
 
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
 
 class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
     pass
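The relocated tests drive tag-scoped and session-wide interruption across worker threads through the suite's `check_job_cancellation` harness, which is not shown in this diff. A hedged sketch of the underlying pattern they exercise, with illustrative job sizes and timings:

```python
import threading
import time

def run_tagged_job(spark, tag):
    # Tags are per-thread: each worker tags its own operations before starting a job.
    spark.addTag(tag)
    try:
        spark.range(10**9).selectExpr("sum(id)").collect()
    except Exception:
        pass  # an interrupted job surfaces as an exception in the worker thread
    finally:
        spark.clearTags()

worker = threading.Thread(target=run_tagged_job, args=(spark, "batch-0"))
worker.start()
time.sleep(1)                  # give the job a moment to start (illustrative)
spark.interruptTag("batch-0")  # cancels only operations carrying this tag
worker.join()
```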
1 change: 0 additions & 1 deletion python/pyspark/sql/tests/test_session.py
@@ -227,7 +227,6 @@ def test_unsupported_api(self):
             (lambda: session.client, "client"),
             (session.addArtifacts, "addArtifact(s)"),
             (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
-            (lambda: session.interruptTag(""), "interruptTag"),
             (lambda: session.interruptOperation(""), "interruptOperation"),
         ]
 