Skip to content

Commit 3854ba4

Browse files
committed
Added automatic per-socket authentication of connections when DB credentials are provided to Connection
1 parent 0b548e6 commit 3854ba4

File tree

4 files changed

+282
-28
lines changed

4 files changed

+282
-28
lines changed

pymongo/collection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ def insert(self, doc_or_docs,
267267
safe = True
268268
self.__database.connection._send_message(
269269
message.insert(self.__full_name, docs,
270-
check_keys, safe, kwargs), safe)
270+
check_keys, safe, kwargs), safe,
271+
collection_name=self.__full_name)
271272

272273
ids = [doc.get("_id", None) for doc in docs]
273274
return return_one and ids[0] or ids
@@ -360,7 +361,8 @@ def update(self, spec, document, upsert=False, manipulate=False,
360361

361362
return self.__database.connection._send_message(
362363
message.update(self.__full_name, upsert, multi,
363-
spec, document, safe, kwargs), safe)
364+
spec, document, safe, kwargs), safe,
365+
collection_name=self.__full_name)
364366

365367
def drop(self):
366368
"""Alias for :meth:`~pymongo.database.Database.drop_collection`.
@@ -433,7 +435,8 @@ def remove(self, spec_or_id=None, safe=False, **kwargs):
433435
safe = True
434436

435437
return self.__database.connection._send_message(
436-
message.delete(self.__full_name, spec_or_id, safe, kwargs), safe)
438+
message.delete(self.__full_name, spec_or_id, safe, kwargs), safe,
439+
collection_name=self.__full_name)
437440

438441
def find_one(self, spec_or_id=None, *args, **kwargs):
439442
"""Get a single document from the database.

pymongo/connection.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,11 @@ def return_socket(self):
172172
self.sock[1].close()
173173
self.sock = None
174174

175+
def socket_ids(self):
176+
return [id(sock) for sock in self.sockets]
175177

176-
class Connection(object): # TODO support auth for pooling
178+
179+
class Connection(object):
177180
"""Connection to MongoDB.
178181
"""
179182

@@ -304,10 +307,15 @@ def __init__(self, host=None, port=None, pool_size=None,
304307
if _connect:
305308
self.__find_master()
306309

310+
# cache of auth username/password credential keyed by DB name
311+
self.__auth_credentials = {}
312+
self.__sock_auths_by_id = {}
307313
if username:
308314
database = database or "admin"
309315
if not self[database].authenticate(username, password):
310316
raise ConfigurationError("authentication failed")
317+
# Add database auth credentials for auto-auth later
318+
self.add_db_auth(database, username, password)
311319

312320
@classmethod
313321
def from_uri(cls, uri="mongodb://localhost", **connection_args):
@@ -614,7 +622,33 @@ def __check_response_to_last_error(self, response):
614622
else:
615623
raise OperationFailure(error["err"])
616624

617-
def _send_message(self, message, with_last_error=False):
625+
def _authenticate_socket_for_db(self, sock, db_name):
626+
# Periodically remove cached auth flags of expired sockets
627+
if len(self.__sock_auths_by_id) > self.pool_size:
628+
cached_sock_ids = self.__sock_auths_by_id.keys()
629+
current_sock_ids = self.__pool.socket_ids()
630+
for sock_id in cached_sock_ids:
631+
if not sock_id in current_sock_ids:
632+
del(self.__sock_auths_by_id[sock_id])
633+
if not self.__auth_credentials:
634+
return # No credentials for any database
635+
sock_id = id(sock)
636+
if db_name in self.__sock_auths_by_id.get(sock_id, {}):
637+
return # Already authenticated for database
638+
if not self.has_db_auth(db_name):
639+
return # No credentials for database
640+
username, password = self.get_db_auth(db_name)
641+
if not self[db_name].authenticate(username, password):
642+
import pdb; pdb.set_trace()
643+
raise ConfigurationError("authentication to db %s failed for %s"
644+
% (db_name, username))
645+
if not sock_id in self.__sock_auths_by_id:
646+
self.__sock_auths_by_id[sock_id] = {}
647+
self.__sock_auths_by_id[sock_id][db_name] = 1
648+
return True
649+
650+
def _send_message(self, message, with_last_error=False,
651+
collection_name=None):
618652
"""Say something to Mongo.
619653
620654
Raises ConnectionFailure if the message cannot be sent. Raises
@@ -630,6 +664,14 @@ def _send_message(self, message, with_last_error=False):
630664
"""
631665
sock = self.__socket()
632666
try:
667+
# Always authenticate for admin database, if possible
668+
if self._authenticate_socket_for_db(sock, 'admin'):
669+
pass # No need for futher auth with admin login
670+
elif collection_name and collection_name.split('.') >= 1:
671+
# Authenticate for specific database
672+
db_name = collection_name.split('.')[0]
673+
self._authenticate_socket_for_db(sock, db_name)
674+
633675
(request_id, data) = message
634676
sock.sendall(data)
635677
# Safe mode. We pack the message together with a lastError
@@ -886,3 +928,29 @@ def __iter__(self):
886928

887929
def next(self):
888930
raise TypeError("'Connection' object is not iterable")
931+
932+
def add_db_auth(self, db_name, username, password):
933+
if not username or not isinstance(username, basestring):
934+
raise ConfigurationError('invalid username')
935+
if not password or not isinstance(password, basestring):
936+
raise ConfigurationError('invalid password')
937+
self.__auth_credentials[db_name] = (username, password)
938+
939+
def has_db_auth(self, db_name):
940+
return db_name in self.__auth_credentials
941+
942+
def get_db_auth(self, db_name):
943+
if self.has_db_auth(db_name):
944+
return self.__auth_credentials[db_name]
945+
return None
946+
947+
def remove_db_auth(self, db_name):
948+
if self.has_db_auth(db_name):
949+
del(self.__auth_credentials[db_name])
950+
# Force close any existing sockets to flush auths
951+
self.disconnect()
952+
953+
def clear_db_auths(self):
954+
self.__auth_credentials = {} # Forget all credentials
955+
# Force close any existing sockets to flush auths
956+
self.disconnect()

test/test_connection.py

Lines changed: 127 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -282,28 +282,46 @@ def test_from_uri(self):
282282
c.pymongo_test.system.users.remove({})
283283

284284
c.admin.add_user("admin", "pass")
285-
c.pymongo_test.add_user("user", "pass")
286-
287-
self.assertRaises(ConfigurationError, Connection,
288-
"mongodb://foo:bar@%s:%s" % (self.host, self.port))
289-
self.assertRaises(ConfigurationError, Connection,
290-
"mongodb://admin:bar@%s:%s" % (self.host, self.port))
291-
self.assertRaises(ConfigurationError, Connection,
292-
"mongodb://user:pass@%s:%s" % (self.host, self.port))
293-
Connection("mongodb://admin:pass@%s:%s" % (self.host, self.port))
294-
295-
self.assertRaises(ConfigurationError, Connection,
296-
"mongodb://admin:pass@%s:%s/pymongo_test" %
297-
(self.host, self.port))
298-
self.assertRaises(ConfigurationError, Connection,
299-
"mongodb://user:foo@%s:%s/pymongo_test" %
300-
(self.host, self.port))
301-
Connection("mongodb://user:pass@%s:%s/pymongo_test" %
302-
(self.host, self.port))
303-
304-
self.assert_(Connection("mongodb://%s:%s" %
305-
(self.host, self.port),
306-
slave_okay=True).slave_okay)
285+
try:
286+
# Not yet logged in
287+
try:
288+
c.admin.system.users.find_one()
289+
# If we get this far auth must not be enabled in server
290+
raise SkipTest()
291+
except OperationFailure:
292+
pass
293+
294+
# Now we log in
295+
c.admin.authenticate("admin", "pass")
296+
297+
c.pymongo_test.add_user("user", "pass")
298+
299+
self.assertRaises(ConfigurationError, Connection,
300+
"mongodb://foo:bar@%s:%s" % (self.host, self.port))
301+
self.assertRaises(ConfigurationError, Connection,
302+
"mongodb://admin:bar@%s:%s" % (self.host, self.port))
303+
self.assertRaises(ConfigurationError, Connection,
304+
"mongodb://user:pass@%s:%s" % (self.host, self.port))
305+
Connection("mongodb://admin:pass@%s:%s" % (self.host, self.port))
306+
307+
self.assertRaises(ConfigurationError, Connection,
308+
"mongodb://admin:pass@%s:%s/pymongo_test" %
309+
(self.host, self.port))
310+
self.assertRaises(ConfigurationError, Connection,
311+
"mongodb://user:foo@%s:%s/pymongo_test" %
312+
(self.host, self.port))
313+
Connection("mongodb://user:pass@%s:%s/pymongo_test" %
314+
(self.host, self.port))
315+
316+
self.assert_(Connection("mongodb://%s:%s" %
317+
(self.host, self.port),
318+
slave_okay=True).slave_okay)
319+
finally:
320+
# Remove auth users from databases
321+
c = Connection(self.host, self.port)
322+
c.admin.authenticate("admin", "pass")
323+
c.admin.system.users.remove({})
324+
c.pymongo_test.system.users.remove({})
307325

308326
def test_fork(self):
309327
"""Test using a connection before and after a fork.
@@ -432,6 +450,93 @@ def test_tz_aware(self):
432450
self.assertEqual(aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
433451
naive.pymongo_test.test.find_one()["x"])
434452

453+
def test_auto_db_authentication(self):
454+
conn = Connection(self.host, self.port)
455+
456+
# Setup admin user
457+
conn.admin.system.users.remove({})
458+
conn.admin.add_user("admin-user", "password")
459+
conn.admin.authenticate("admin-user", "password")
460+
461+
try: # try/finally to ensure we remove admin user
462+
# Setup test database user
463+
conn.pymongo_test.system.users.remove({})
464+
conn.pymongo_test.add_user("test-user", "password")
465+
466+
conn.pymongo_test.drop_collection("test")
467+
468+
self.assertRaises(TypeError, conn.add_db_auth, "", "password")
469+
self.assertRaises(TypeError, conn.add_db_auth, 5, "password")
470+
self.assertRaises(TypeError, conn.add_db_auth, "test-user", "")
471+
self.assertRaises(TypeError, conn.add_db_auth, "test-user", 5)
472+
473+
# Not yet logged in
474+
conn = Connection(self.host, self.port)
475+
try:
476+
conn.admin.system.users.find_one()
477+
# If we get this far auth must not be enabled in server
478+
raise SkipTest()
479+
except OperationFailure:
480+
pass
481+
482+
# Not yet logged in
483+
conn = Connection(self.host, self.port)
484+
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
485+
self.assertFalse(conn.has_db_auth('admin'))
486+
self.assertEquals(None, conn.get_db_auth('admin'))
487+
488+
# Admin log in via URI
489+
conn = Connection('admin-user:password@%s' % self.host, self.port)
490+
self.assertTrue(conn.has_db_auth('admin'))
491+
self.assertEquals('admin-user', conn.get_db_auth('admin')[0])
492+
conn.admin.system.users.find()
493+
conn.pymongo_test.test.insert({'_id':1, 'test':'data'}, safe=True)
494+
self.assertEquals(1, conn.pymongo_test.test.find({'_id':1}).count())
495+
conn.pymongo_test.test.remove({'_id':1})
496+
497+
# Clear and reset database authentication for all sockets
498+
conn.clear_db_auths()
499+
self.assertFalse(conn.has_db_auth('admin'))
500+
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
501+
502+
# Admin log in via add_db_auth
503+
conn = Connection(self.host, self.port)
504+
conn.admin.system.users.find()
505+
conn.add_db_auth('admin', 'admin-user', 'password')
506+
conn.pymongo_test.test.insert({'_id':2, 'test':'data'}, safe=True)
507+
self.assertEquals(1, conn.pymongo_test.test.find({'_id':2}).count())
508+
conn.pymongo_test.test.remove({'_id':2})
509+
510+
# Remove database authentication for specific database
511+
self.assertTrue(conn.has_db_auth('admin'))
512+
conn.remove_db_auth('admin')
513+
self.assertFalse(conn.has_db_auth('admin'))
514+
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
515+
516+
# Incorrect admin credentials
517+
conn = Connection(self.host, self.port)
518+
conn.add_db_auth('admin', 'admin-user', 'wrong-password')
519+
self.assertRaises(OperationFailure, conn.pymongo_test.test.count)
520+
521+
# Database-specific log in
522+
conn = Connection(self.host, self.port)
523+
conn.add_db_auth('pymongo_test', 'test-user', 'password')
524+
self.assertRaises(OperationFailure, conn.admin.system.users.find_one)
525+
conn.pymongo_test.test.insert({'_id':3, 'test':'data'}, safe=True)
526+
self.assertEquals(1, conn.pymongo_test.test.find({'_id':3}).count())
527+
conn.pymongo_test.test.remove({'_id':3})
528+
529+
# Incorrect database credentials
530+
conn = Connection(self.host, self.port)
531+
conn.add_db_auth('pymongo_test', 'wrong-user', 'password')
532+
self.assertRaises(OperationFailure, conn.pymongo_test.test.find_one)
533+
finally:
534+
# Remove auth users from databases
535+
conn = Connection(self.host, self.port)
536+
conn.admin.authenticate("admin-user", "password")
537+
conn.admin.system.users.remove({})
538+
conn.pymongo_test.system.users.remove({})
539+
435540

436541
if __name__ == "__main__":
437542
unittest.main()

0 commit comments

Comments
 (0)