Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,64 +54,6 @@
}


class Py4jCallbackConnectionCleaner(object):

    """
    A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
    It will scan all callback connections every 30 seconds and close the dead connections.

    :param gateway: the Py4j gateway whose ``_callback_server`` connections are scanned.
    """

    def __init__(self, gateway):
        self._gateway = gateway
        self._stopped = False
        # The currently scheduled Timer, or None when nothing is scheduled.
        self._timer = None
        # Guards _stopped/_timer so stop() cannot race with rescheduling.
        self._lock = RLock()

    def start(self):
        """
        Schedule the periodic cleaning task. Does nothing once `stop` has been called.
        """
        if self._stopped:
            return

        def clean_closed_connections():
            # Everything that can fail is inside the try so that a single bad scan
            # never prevents the next one from being scheduled.
            try:
                callback_server = self._gateway._callback_server
                # The callback server may not have been started; nothing to clean then.
                if callback_server:
                    from py4j.java_gateway import quiet_close, quiet_shutdown

                    with callback_server.lock:
                        closed_connections = []
                        for connection in callback_server.connections:
                            # `Thread.isAlive` was removed in Python 3.9; `is_alive`
                            # is available since Python 2.6.
                            if not connection.is_alive():
                                quiet_close(connection.input)
                                quiet_shutdown(connection.socket)
                                quiet_close(connection.socket)
                                closed_connections.append(connection)

                        # Remove after the scan so the collection is not mutated
                        # while it is being iterated.
                        for closed_connection in closed_connections:
                            callback_server.connections.remove(closed_connection)
            except Exception:
                import traceback
                traceback.print_exc()

            # Re-arm the timer for the next scan (no-op once stopped).
            self._start_timer(clean_closed_connections)

        self._start_timer(clean_closed_connections)

    def _start_timer(self, f):
        """
        Run `f` on a daemon timer thread after 30 seconds, unless the cleaner is stopped.
        """
        from threading import Timer

        with self._lock:
            if not self._stopped:
                self._timer = Timer(30.0, f)
                # Daemon thread: must not keep the interpreter alive at exit.
                self._timer.daemon = True
                self._timer.start()

    def stop(self):
        """
        Stop the cleaner and cancel the pending timer, if any. The cleaner cannot be restarted.
        """
        with self._lock:
            self._stopped = True
            if self._timer:
                self._timer.cancel()
                self._timer = None


class SparkContext(object):

"""
Expand All @@ -126,7 +68,6 @@ class SparkContext(object):
_active_spark_context = None
_lock = RLock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
_py4j_cleaner = None

PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

Expand Down Expand Up @@ -303,8 +244,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
_py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway)
_py4j_cleaner.start()

if instance:
if (SparkContext._active_spark_context and
Expand Down
63 changes: 63 additions & 0 deletions python/pyspark/streaming/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os
import sys
from threading import RLock, Timer

from py4j.java_gateway import java_import, JavaObject

Expand All @@ -32,6 +33,63 @@
__all__ = ["StreamingContext"]


class Py4jCallbackConnectionCleaner(object):

    """
    A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
    It will scan all callback connections every 30 seconds and close the dead connections.

    :param gateway: the Py4j gateway whose ``_callback_server`` connections are scanned.
    """

    def __init__(self, gateway):
        self._gateway = gateway
        self._stopped = False
        # The currently scheduled Timer, or None when nothing is scheduled.
        self._timer = None
        # Guards _stopped/_timer so stop() cannot race with rescheduling.
        self._lock = RLock()

    def start(self):
        """
        Schedule the periodic cleaning task. Does nothing once `stop` has been called.
        """
        if self._stopped:
            return

        def clean_closed_connections():
            # Everything that can fail is inside the try so that a single bad scan
            # never prevents the next one from being scheduled.
            try:
                callback_server = self._gateway._callback_server
                # Defensive check: the callback server may not exist (yet).
                if callback_server:
                    from py4j.java_gateway import quiet_close, quiet_shutdown

                    with callback_server.lock:
                        closed_connections = []
                        for connection in callback_server.connections:
                            # `Thread.isAlive` was removed in Python 3.9; `is_alive`
                            # is available since Python 2.6.
                            if not connection.is_alive():
                                quiet_close(connection.input)
                                quiet_shutdown(connection.socket)
                                quiet_close(connection.socket)
                                closed_connections.append(connection)

                        # Remove after the scan so the collection is not mutated
                        # while it is being iterated.
                        for closed_connection in closed_connections:
                            callback_server.connections.remove(closed_connection)
            except Exception:
                import traceback
                traceback.print_exc()

            # Re-arm the timer for the next scan (no-op once stopped).
            self._start_timer(clean_closed_connections)

        self._start_timer(clean_closed_connections)

    def _start_timer(self, f):
        """
        Run `f` on a daemon timer thread after 30 seconds, unless the cleaner is stopped.
        """
        with self._lock:
            if not self._stopped:
                self._timer = Timer(30.0, f)
                # Daemon thread: must not keep the interpreter alive at exit.
                self._timer.daemon = True
                self._timer.start()

    def stop(self):
        """
        Stop the cleaner and cancel the pending timer, if any. The cleaner cannot be restarted.
        """
        with self._lock:
            self._stopped = True
            if self._timer:
                self._timer.cancel()
                self._timer = None


class StreamingContext(object):
"""
Main entry point for Spark Streaming functionality. A StreamingContext
Expand All @@ -47,6 +105,9 @@ class StreamingContext(object):
# Reference to a currently active StreamingContext
_activeContext = None

# A cleaner that closes leaked callback server sockets every 30 seconds
_py4j_cleaner = None

def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
Expand Down Expand Up @@ -95,6 +156,8 @@ def _ensure_initialized(cls):
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
_py4j_cleaner = Py4jCallbackConnectionCleaner(gw)
_py4j_cleaner.start()

# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
Expand Down