@@ -198,18 +198,41 @@ class Cluster(object):
198198 Setting this to :const:`False` disables compression.
199199 """
200200
201- auth_provider = None
202- """
203- When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
204- be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
205- such ass :class:`~.PlainTextAuthProvider`.
201+ _auth_provider = None
202+ _auth_provider_callable = None
206203
207- When :attr:`~.Cluster.protocol_version` is 1, this should be
208- a function that accepts one argument, the IP address of a node,
209- and returns a dict of credentials for that node.
204+ @property
205+ def auth_provider (self ):
206+ """
207+ When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
208+ be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
209+ such ass :class:`~.PlainTextAuthProvider`.
210210
211- When not using authentication, this should be left as :const:`None`.
212- """
211+ When :attr:`~.Cluster.protocol_version` is 1, this should be
212+ a function that accepts one argument, the IP address of a node,
213+ and returns a dict of credentials for that node.
214+
215+ When not using authentication, this should be left as :const:`None`.
216+ """
217+ return self ._auth_provider
218+
219+ @auth_provider .setter # noqa
220+ def auth_provider (self , value ):
221+ if not value :
222+ self ._auth_provider = value
223+ return
224+
225+ try :
226+ self ._auth_provider_callable = value .new_authenticator
227+ except AttributeError :
228+ if self .protocol_version > 1 :
229+ raise TypeError ("auth_provider must implement the cassandra.auth.AuthProvider "
230+ "interface when protocol_version >= 2" )
231+ elif not callable (value ):
232+ raise TypeError ("auth_provider must be callable when protocol_version == 1" )
233+ self ._auth_provider_callable = value
234+
235+ self ._auth_provider = value
213236
214237 load_balancing_policy = None
215238 """
@@ -339,15 +362,8 @@ def __init__(self,
339362 self .contact_points = contact_points
340363 self .port = port
341364 self .compression = compression
342-
343- if auth_provider is not None :
344- if not hasattr (auth_provider , 'new_authenticator' ):
345- if protocol_version > 1 :
346- raise TypeError ("auth_provider must implement the cassandra.auth.AuthProvider "
347- "interface when protocol_version >= 2" )
348- self .auth_provider = auth_provider
349- else :
350- self .auth_provider = auth_provider .new_authenticator
365+ self .protocol_version = protocol_version
366+ self .auth_provider = auth_provider
351367
352368 if load_balancing_policy is not None :
353369 if isinstance (load_balancing_policy , type ):
@@ -381,7 +397,6 @@ def __init__(self,
381397 self .ssl_options = ssl_options
382398 self .sockopts = sockopts
383399 self .cql_version = cql_version
384- self .protocol_version = protocol_version
385400 self .max_schema_agreement_wait = max_schema_agreement_wait
386401 self .control_connection_timeout = control_connection_timeout
387402
@@ -492,8 +507,8 @@ def _make_connection_factory(self, host, *args, **kwargs):
492507 return partial (self .connection_class .factory , host .address , * args , ** kwargs )
493508
494509 def _make_connection_kwargs (self , address , kwargs_dict ):
495- if self .auth_provider :
496- kwargs_dict ['authenticator' ] = self .auth_provider (address )
510+ if self ._auth_provider_callable :
511+ kwargs_dict ['authenticator' ] = self ._auth_provider_callable (address )
497512
498513 kwargs_dict ['port' ] = self .port
499514 kwargs_dict ['compression' ] = self .compression
0 commit comments