@@ -217,14 +217,15 @@ class PoolOptions(object):
217217 '__connect_timeout' , '__socket_timeout' ,
218218 '__wait_queue_timeout' , '__wait_queue_multiple' ,
219219 '__ssl_context' , '__ssl_match_hostname' , '__socket_keepalive' ,
220- '__event_listeners' , '__appname' , '__metadata' )
220+ '__event_listeners' , '__appname' , '__metadata' ,
221+ '__handshake_callback' )
221222
222223 def __init__ (self , max_pool_size = 100 , min_pool_size = 0 ,
223224 max_idle_time_ms = None , connect_timeout = None ,
224225 socket_timeout = None , wait_queue_timeout = None ,
225226 wait_queue_multiple = None , ssl_context = None ,
226227 ssl_match_hostname = True , socket_keepalive = False ,
227- event_listeners = None , appname = None ):
228+ event_listeners = None , appname = None , handshake_callback = None ):
228229
229230 self .__max_pool_size = max_pool_size
230231 self .__min_pool_size = min_pool_size
@@ -242,6 +243,27 @@ def __init__(self, max_pool_size=100, min_pool_size=0,
242243 if appname :
243244 self .__metadata ['application' ] = {'name' : appname }
244245
246+ self .__handshake_callback = handshake_callback
247+
248+ def with_options (self , ** kwargs ):
249+ options = {
250+ 'max_pool_size' : self .max_pool_size ,
251+ 'min_pool_size' : self .min_pool_size ,
252+ 'max_idle_time_ms' : self .max_idle_time_ms ,
253+ 'connect_timeout' : self .connect_timeout ,
254+ 'socket_timeout' : self .socket_timeout ,
255+ 'wait_queue_timeout' : self .wait_queue_timeout ,
256+ 'wait_queue_multiple' : self .wait_queue_multiple ,
257+ 'ssl_context' : self .ssl_context ,
258+ 'ssl_match_hostname' : self .ssl_match_hostname ,
259+ 'socket_keepalive' : self .socket_keepalive ,
260+ 'event_listeners' : self .event_listeners ,
261+ 'appname' : self .appname ,
262+ 'handshake_callback' : self .handshake_callback }
263+
264+ options .update (kwargs )
265+ return PoolOptions (** options )
266+
245267 @property
246268 def max_pool_size (self ):
247269 """The maximum allowable number of concurrent connections to each
@@ -335,6 +357,11 @@ def metadata(self):
335357 """
336358 return self .__metadata .copy ()
337359
360+ @property
361+ def handshake_callback (self ):
362+ """Receives an ismaster reply and updates the topology."""
363+ return self .__handshake_callback
364+
338365
339366class SocketInfo (object ):
340367 """Store a socket with some metadata.
@@ -746,6 +773,8 @@ def connect(self):
746773 ('ismaster' , 1 ),
747774 ('client' , self .opts .metadata )
748775 ])
776+
777+ start = _time ()
749778 ismaster = IsMaster (
750779 command (sock ,
751780 'admin' ,
@@ -754,13 +783,20 @@ def connect(self):
754783 False ,
755784 ReadPreference .PRIMARY ,
756785 DEFAULT_CODEC_OPTIONS ))
786+
787+ # Can raise ConnectionFailure.
788+ self ._handshake_callback (ismaster , _time () - start )
757789 else :
758790 ismaster = None
759791 return SocketInfo (sock , self , ismaster , self .address )
760792 except socket .error as error :
761793 if sock is not None :
762794 sock .close ()
763795 _raise_connection_failure (self .address , error )
796+ except :
797+ if sock is not None :
798+ sock .close ()
799+ raise
764800
765801 @contextlib .contextmanager
766802 def get_socket (self , all_credentials , checkout = False ):
@@ -889,6 +925,14 @@ def _check(self, sock_info):
889925 else :
890926 return self .connect ()
891927
928+ def _handshake_callback (self , ismaster , round_trip_time ):
929+ callback = self .opts .handshake_callback
930+ if callback :
931+ kept = callback (self .address , ismaster , round_trip_time )
932+ if not kept :
933+ _raise_connection_failure (
934+ self .address , "server removed from topology" )
935+
892936 def _raise_wait_queue_timeout (self ):
893937 raise ConnectionFailure (
894938 'Timed out waiting for socket from pool with max_size %r and'
0 commit comments