From 05bcf4c75dcc518d9e87e744f7999eda696d665d Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Tue, 18 Jun 2024 12:35:26 +0900
Subject: [PATCH] Make tags properly threadlocal

---
 python/pyspark/sql/connect/client/core.py     |  9 ++----
 .../pyspark/sql/tests/connect/test_session.py | 28 +++++++++++++++++++
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 4c638be3b0af..f3bbab69f271 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -659,12 +659,7 @@ def __init__(
         use_reattachable_execute: bool
             Enable reattachable execution.
         """
-
-        class ClientThreadLocals(threading.local):
-            tags: set = set()
-            inside_error_handling: bool = False
-
-        self.thread_local = ClientThreadLocals()
+        self.thread_local = threading.local()
 
         # Parse the connection string.
         self._builder = (
@@ -1693,7 +1688,7 @@ def _handle_error(self, error: Exception) -> NoReturn:
         Throws the appropriate internal Python exception.
         """
 
-        if self.thread_local.inside_error_handling:
+        if getattr(self.thread_local, "inside_error_handling", False):
             # We are already inside error handling routine,
             # avoid recursive error processing (with potentially infinite recursion)
             raise error
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 820f54b83327..6f0e4aaad3f8 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -119,6 +119,34 @@ def test_tags(self):
         self.assertEqual(self.spark.getTags(), set())
         self.spark.clearTags()
 
+    def test_tags_multithread(self):
+        output1 = None
+        output2 = None
+
+        def tag1():
+            nonlocal output1
+
+            self.spark.addTag("tag1")
+            output1 = self.spark.getTags()
+
+        def tag2():
+            nonlocal output2
+
+            self.spark.addTag("tag2")
+            output2 = self.spark.getTags()
+
+        t1 = threading.Thread(target=tag1)
+        t1.start()
+        t1.join()
+        t2 = threading.Thread(target=tag2)
+        t2.start()
+        t2.join()
+
+        self.assertIsNotNone(output1)
+        self.assertEqual(output1, {"tag1"})
+        self.assertIsNotNone(output2)
+        self.assertEqual(output2, {"tag2"})
+
     def test_interrupt_tag(self):
         thread_ids = range(4)
         self.check_job_cancellation(