@@ -209,10 +209,20 @@ def test_get_socket_and_exception(self):
209209
210210 self .assertEqual (1 , len (cx_pool .sockets ))
211211
212+ def test_pool_removes_closed_socket (self ):
213+ # Test that Pool removes explicitly closed socket.
214+ cx_pool = self .create_pool ()
215+
216+ with cx_pool .get_socket ({}, 0 , 0 ) as sock_info :
217+ # Use SocketInfo's API to close the socket.
218+ sock_info .close ()
219+
220+ self .assertEqual (0 , len (cx_pool .sockets ))
221+
212222 def test_pool_removes_dead_socket (self ):
213223 # Test that Pool removes dead socket and the socket doesn't return
214224 # itself PYTHON-344
215- cx_pool = self .create_pool (max_pool_size = 10 )
225+ cx_pool = self .create_pool (max_pool_size = 1 , wait_queue_timeout = 1 )
216226 cx_pool ._check_interval_seconds = 0 # Always check.
217227
218228 with cx_pool .get_socket ({}, 0 , 0 ) as sock_info :
@@ -227,6 +237,42 @@ def test_pool_removes_dead_socket(self):
227237
228238 self .assertEqual (1 , len (cx_pool .sockets ))
229239
240+ # Semaphore was released.
241+ with cx_pool .get_socket ({}, 0 , 0 ):
242+ pass
243+
244+ def test_return_socket_after_reset (self ):
245+ pool = self .create_pool ()
246+ with pool .get_socket ({}, 0 , 0 ) as sock :
247+ pool .reset ()
248+
249+ self .assertTrue (sock .closed )
250+ self .assertEqual (0 , len (pool .sockets ))
251+
252+ def test_pool_check (self ):
253+ # Test that Pool recovers from two connection failures in a row.
254+ # This exercises code at the end of Pool._check().
255+ cx_pool = self .create_pool (max_pool_size = 1 ,
256+ connect_timeout = 1 ,
257+ wait_queue_timeout = 1 )
258+ cx_pool ._check_interval_seconds = 0 # Always check.
259+
260+ with cx_pool .get_socket ({}, 0 , 0 ) as sock_info :
261+ # Simulate a closed socket without telling the SocketInfo it's
262+ # closed.
263+ sock_info .sock .close ()
264+
265+ # Swap pool's address with a bad one.
266+ address , cx_pool .address = cx_pool .address , ('foo.com' , 1234 )
267+ with self .assertRaises (socket .error ):
268+ with cx_pool .get_socket ({}, 0 , 0 ):
269+ pass
270+
271+ # Back to normal, semaphore was correctly released.
272+ cx_pool .address = address
273+ with cx_pool .get_socket ({}, 0 , 0 , checkout = True ):
274+ pass
275+
230276 def test_pool_with_fork (self ):
231277 # Test that separate MongoClients have separate Pools, and that the
232278 # driver can create a new MongoClient after forking
@@ -438,11 +484,7 @@ def f():
438484 def test_max_pool_size_with_connection_failure (self ):
439485 # The pool acquires its semaphore before attempting to connect; ensure
440486 # it releases the semaphore on connection failure.
441- class TestPool (Pool ):
442- def connect (self ):
443- raise socket .error ()
444-
445- test_pool = TestPool (
487+ test_pool = Pool (
446488 ('example.com' , 27017 ),
447489 PoolOptions (
448490 max_pool_size = 1 ,
0 commit comments