diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 1677d3e8e020..a35fccbcffe9 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -52,6 +52,8 @@ See also :class:`SparkSession`.
     SparkSession.dataSource
     SparkSession.getActiveSession
     SparkSession.getTags
+    SparkSession.interruptAll
+    SparkSession.interruptTag
     SparkSession.newSession
     SparkSession.profile
     SparkSession.removeTag
@@ -86,8 +88,6 @@ Spark Connect Only
     SparkSession.clearProgressHandlers
     SparkSession.client
     SparkSession.copyFromLocalToFs
-    SparkSession.interruptAll
     SparkSession.interruptOperation
-    SparkSession.interruptTag
     SparkSession.registerProgressHandler
     SparkSession.removeProgressHandler
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f3a1639fddaf..fc434cd16bfb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -2197,13 +2197,15 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
             messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
         )
 
-    @remote_only
     def interruptAll(self) -> List[str]:
         """
         Interrupt all operations of this session currently running on the connected server.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2213,18 +2215,25 @@
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptAll"},
-        )
+        java_list = self._jsparkSession.interruptAll()
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
-    @remote_only
     def interruptTag(self, tag: str) -> List[str]:
         """
         Interrupt all operations of this session with the given operation tag.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2234,10 +2243,15 @@
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptTag"},
-        )
+        java_list = self._jsparkSession.interruptTag(tag)
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
     @remote_only
     def interruptOperation(self, op_id: str) -> List[str]:
diff --git a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
index c5184b04d6aa..ddb4554afa55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
+++ b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
@@ -32,28 +32,6 @@ def func(target):
             create_thread=lambda target, session: threading.Thread(target=func, args=(target,))
         )
 
-    def test_interrupt_tag(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: self.spark.addTag(job_group),
-            lambda job_group: self.spark.interruptTag(job_group),
-            thread_ids,
-            [i for i in thread_ids if i % 2 == 0],
-            [i for i in thread_ids if i % 2 != 0],
-        )
-        self.spark.clearTags()
-
-    def test_interrupt_all(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: None,
-            lambda job_group: self.spark.interruptAll(),
-            thread_ids,
-            thread_ids,
-            [],
-        )
-        self.spark.clearTags()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index ef83dc3834d0..25b8be1f9ac7 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -266,9 +266,7 @@ def test_spark_session_compatibility(self):
             "addArtifacts",
             "clearProgressHandlers",
             "copyFromLocalToFs",
-            "interruptAll",
             "interruptOperation",
-            "interruptTag",
             "newSession",
             "registerProgressHandler",
             "removeProgressHandler",
diff --git a/python/pyspark/sql/tests/test_job_cancellation.py b/python/pyspark/sql/tests/test_job_cancellation.py
index a046c9c01811..3f30f7880889 100644
--- a/python/pyspark/sql/tests/test_job_cancellation.py
+++ b/python/pyspark/sql/tests/test_job_cancellation.py
@@ -166,6 +166,28 @@ def get_outer_local_prop():
         self.assertEqual(first, {"a", "b"})
         self.assertEqual(second, {"a", "b", "c"})
 
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
 
 class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 3fbc0be943e4..a22fe777e3c9 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -227,7 +227,6 @@ def test_unsupported_api(self):
             (lambda: session.client, "client"),
             (session.addArtifacts, "addArtifact(s)"),
             (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
-            (lambda: session.interruptTag(""), "interruptTag"),
             (lambda: session.interruptOperation(""), "interruptOperation"),
         ]