3838import select
3939import socket
4040import struct
41- import threading
41+ import thread
4242import time
4343import warnings
4444
6060_CONNECT_TIMEOUT = 20.0
6161
6262
63+ try :
64+ from greenlet import greenlet
65+ except ImportError :
66+ def _thread_identifier ():
67+ """Return the identifier of the current thread-of-execution."""
68+ return os .getpid (), thread .get_ident ()
69+ else :
70+ def _thread_identifier ():
71+ """Return the identifier of the current thread-of-execution.
72+ Supports greenlets.
73+ """
74+ return os .getpid (), thread .get_ident (), greenlet .getcurrent ()
75+
76+
77+
6378def _closed (sock ):
6479 """Return True if we know socket has been closed, False otherwise.
6580 """
@@ -84,25 +99,20 @@ def _partition_node(node):
8499 return host , port
85100
86101
87- class _Pool (threading .local ):
102+
103+ class _Pool (object ):
88104 """A simple connection pool.
89105
90- Uses thread-local socket per thread. By calling return_socket() a
91- thread can return a socket to the pool.
106+ Uses thread-local socket per thread (including greenlets).
107+ By calling return_socket() a thread can return a socket to the pool.
92108 """
93109
94- # Non thread-locals
95- __slots__ = ["sockets" , "pool_size" , "pid" ]
96-
97- # thread-local default
98- sock = None
99-
100110 def __init__ (self , pool_size , network_timeout ):
101111 self .pid = os .getpid ()
102112 self .pool_size = pool_size
103113 self .network_timeout = network_timeout
104- if not hasattr ( self , " sockets" ):
105- self .sockets = []
114+ self . sockets = []
115+ self .active_sockets = {}
106116
107117 def connect (self , host , port ):
108118 """Connect to Mongo and return a new (connected) socket.
@@ -126,36 +136,30 @@ def connect(self, host, port):
126136 return s
127137
128138 def get_socket (self , host , port ):
129- # We use the pid here to avoid issues with fork / multiprocessing.
139+ # We use the _thread_identifier here to avoid issues with multiple
140+ # threads of execution (processes, proper threads, greenlets)
130141 # See test.test_connection:TestConnection.test_fork for an example of
131142 # what could go wrong otherwise
132- pid = os .getpid ()
133-
134- if pid != self .pid :
135- self .sock = None
136- self .sockets = []
137- self .pid = pid
138-
139- if self .sock is not None and self .sock [0 ] == pid :
140- return self .sock [1 ]
143+ sock_id = _thread_identifier ()
141144
142145 try :
143- self .sock = (pid , self .sockets .pop ())
144- except IndexError :
145- self .sock = (pid , self .connect (host , port ))
146-
147- return self .sock [1 ]
146+ sock = self .active_sockets [sock_id ]
147+ except KeyError :
148+ try :
149+ sock = self .sockets .pop ()
150+ except IndexError :
151+ sock = self .connect (host , port )
152+ self .active_sockets [sock_id ] = sock
153+ return sock
148154
149155 def return_socket (self ):
150- if self .sock is not None and self .sock [0 ] == os .getpid ():
156+ sock = self .active_sockets .pop (_thread_identifier (), None )
157+ if sock is not None :
151158 # There's a race condition here, but we deliberately
152159 # ignore it. It means that if the pool_size is 10 we
153160 # might actually keep slightly more than that.
154161 if len (self .sockets ) < self .pool_size :
155- self .sockets .append (self .sock [1 ])
156- else :
157- self .sock [1 ].close ()
158- self .sock = None
162+ self .sockets .append (sock )
159163
160164
161165class Connection (common .BaseObject ): # TODO support auth for pooling
0 commit comments