Skip to content

Commit e59fe85

Browse files
committed
Move auth_provider validation to attr setter
1 parent 33c523b commit e59fe85

File tree

2 files changed

+48
-23
lines changed

2 files changed

+48
-23
lines changed

cassandra/cluster.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -198,18 +198,41 @@ class Cluster(object):
198198
Setting this to :const:`False` disables compression.
199199
"""
200200

201-
auth_provider = None
202-
"""
203-
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
204-
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
205-
such ass :class:`~.PlainTextAuthProvider`.
201+
_auth_provider = None
202+
_auth_provider_callable = None
206203

207-
When :attr:`~.Cluster.protocol_version` is 1, this should be
208-
a function that accepts one argument, the IP address of a node,
209-
and returns a dict of credentials for that node.
204+
@property
205+
def auth_provider(self):
206+
"""
207+
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
208+
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
209+
such ass :class:`~.PlainTextAuthProvider`.
210210
211-
When not using authentication, this should be left as :const:`None`.
212-
"""
211+
When :attr:`~.Cluster.protocol_version` is 1, this should be
212+
a function that accepts one argument, the IP address of a node,
213+
and returns a dict of credentials for that node.
214+
215+
When not using authentication, this should be left as :const:`None`.
216+
"""
217+
return self._auth_provider
218+
219+
@auth_provider.setter # noqa
220+
def auth_provider(self, value):
221+
if not value:
222+
self._auth_provider = value
223+
return
224+
225+
try:
226+
self._auth_provider_callable = value.new_authenticator
227+
except AttributeError:
228+
if self.protocol_version > 1:
229+
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
230+
"interface when protocol_version >= 2")
231+
elif not callable(value):
232+
raise TypeError("auth_provider must be callable when protocol_version == 1")
233+
self._auth_provider_callable = value
234+
235+
self._auth_provider = value
213236

214237
load_balancing_policy = None
215238
"""
@@ -339,15 +362,8 @@ def __init__(self,
339362
self.contact_points = contact_points
340363
self.port = port
341364
self.compression = compression
342-
343-
if auth_provider is not None:
344-
if not hasattr(auth_provider, 'new_authenticator'):
345-
if protocol_version > 1:
346-
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
347-
"interface when protocol_version >= 2")
348-
self.auth_provider = auth_provider
349-
else:
350-
self.auth_provider = auth_provider.new_authenticator
365+
self.protocol_version = protocol_version
366+
self.auth_provider = auth_provider
351367

352368
if load_balancing_policy is not None:
353369
if isinstance(load_balancing_policy, type):
@@ -381,7 +397,6 @@ def __init__(self,
381397
self.ssl_options = ssl_options
382398
self.sockopts = sockopts
383399
self.cql_version = cql_version
384-
self.protocol_version = protocol_version
385400
self.max_schema_agreement_wait = max_schema_agreement_wait
386401
self.control_connection_timeout = control_connection_timeout
387402

@@ -492,8 +507,8 @@ def _make_connection_factory(self, host, *args, **kwargs):
492507
return partial(self.connection_class.factory, host.address, *args, **kwargs)
493508

494509
def _make_connection_kwargs(self, address, kwargs_dict):
495-
if self.auth_provider:
496-
kwargs_dict['authenticator'] = self.auth_provider(address)
510+
if self._auth_provider_callable:
511+
kwargs_dict['authenticator'] = self._auth_provider_callable(address)
497512

498513
kwargs_dict['port'] = self.port
499514
kwargs_dict['compression'] = self.compression

tests/integration/standard/test_cluster.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,18 @@ def test_auth_provider_is_callable(self):
117117
"""
118118
Ensure that auth_providers are always callable
119119
"""
120+
self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1)
121+
c = Cluster(protocol_version=1)
122+
self.assertRaises(TypeError, setattr, c, 'auth_provider', 1)
120123

121-
self.assertRaises(ValueError, Cluster, auth_provider=1)
124+
def test_v2_auth_provider(self):
125+
"""
126+
Check for v2 auth_provider compliance
127+
"""
128+
bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'}
129+
self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2)
130+
c = Cluster(protocol_version=2)
131+
self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider)
122132

123133
def test_conviction_policy_factory_is_callable(self):
124134
"""

0 commit comments

Comments
 (0)