diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 1276c31b33737..30ad04297c682 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -228,6 +228,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry auth_token = self.server.auth_token + def poll(func): while not self.server.server_shutdown: # Poll every 1 second for new data -- don't block in case of shutdown. @@ -254,13 +255,15 @@ def authenticate_and_accum_updates(): # we've authenticated, we can break out of the first loop now return True else: - raise Exception("The value of the provided token to the AccumulatorServer is not correct.") + raise Exception( + "The value of the provided token to the AccumulatorServer is not correct.") # first we keep polling till we've received the authentication token poll(authenticate_and_accum_updates) # now we've authenticated, don't need to check for the token anymore poll(accum_updates) + class AccumulatorServer(SocketServer.TCPServer): def __init__(self, server_address, RequestHandlerClass, auth_token): @@ -278,6 +281,7 @@ def shutdown(self): SocketServer.TCPServer.shutdown(self) self.server_close() + def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)