Skip to content

Commit 41096d5

Browse files
committed
PYTHON-1332 - Check current user owns session
1 parent 410027c commit 41096d5

File tree

6 files changed

+85
-27
lines changed

6 files changed

+85
-27
lines changed

pymongo/bulk.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def execute_command(self, sock_info, generator, write_concern, session):
310310
listeners = self.collection.database.client._event_listeners
311311

312312
with self.collection.database.client._tmp_session(session) as s:
313+
# sock_info.command checks auth, but we use sock_info.write_command.
314+
sock_info.check_session_auth_matches(s)
313315
for run in generator:
314316
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
315317
('ordered', self.ordered)])

pymongo/client_session.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,13 @@ def causally_consistent_reads(self):
6666

6767

6868
class ClientSession(object):
69-
"""A session for ordering sequential operations.
70-
71-
:Parameters:
72-
- `client`: A :class:`~pymongo.mongo_client.MongoClient`.
73-
- `options` (optional): A :class:`SessionOptions` instance.
74-
"""
75-
def __init__(self, client, options=None):
69+
"""A session for ordering sequential operations."""
70+
def __init__(self, client, server_session, options, authset):
71+
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
7672
self._client = client
77-
78-
if options is not None:
79-
self._options = options
80-
else:
81-
self._options = SessionOptions()
82-
83-
# Raises ConfigurationError if sessions are not supported.
84-
self._server_session = client._get_server_session()
73+
self._server_session = server_session
74+
self._options = options
75+
self._authset = authset
8576

8677
def end_session(self):
8778
"""Finish this session.

pymongo/message.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ def use_command(self, sock_info, exhaust):
274274
'Specifying a collation is unsupported with a max wire '
275275
'version of %d.' % (sock_info.max_wire_version,))
276276

277+
sock_info.check_session_auth_matches(self.session)
278+
277279
return use_find_cmd
278280

279281
def as_command(self):
@@ -345,6 +347,7 @@ def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, session,
345347
self.session = session
346348

347349
def use_command(self, sock_info, exhaust):
350+
sock_info.check_session_auth_matches(self.session)
348351
return sock_info.max_wire_version >= 4 and not exhaust
349352

350353
def as_command(self):

pymongo/mongo_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,12 +1219,15 @@ def start_session(self, **kwargs):
12191219
# Driver Sessions Spec: "If startSession is called when multiple users
12201220
# are authenticated drivers MUST raise an error with the error message
12211221
# 'Cannot call startSession when multiple users are authenticated.'"
1222-
if len(self.__all_credentials) > 1:
1222+
authset = set(self.__all_credentials.values())
1223+
if len(authset) > 1:
12231224
raise InvalidOperation("Cannot call start_session when"
12241225
" multiple users are authenticated")
12251226

1227+
# Raises ConfigurationError if sessions are not supported.
1228+
server_session = self._get_server_session()
12261229
opts = client_session.SessionOptions(**kwargs)
1227-
return client_session.ClientSession(self, opts)
1230+
return client_session.ClientSession(self, server_session, opts, authset)
12281231

12291232
def _get_server_session(self):
12301233
"""Internal: start or resume a _ServerSession."""

pymongo/pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class SSLError(socket.error):
3737
from pymongo.errors import (AutoReconnect,
3838
ConnectionFailure,
3939
ConfigurationError,
40+
InvalidOperation,
4041
DocumentTooLarge,
4142
NetworkTimeout,
4243
NotMasterError,
@@ -437,7 +438,7 @@ def command(self, dbname, spec, slave_ok=False,
437438
parse_write_concern_error=False,
438439
collation=None,
439440
session=None):
440-
"""Execute a command or raise ConnectionFailure or OperationFailure.
441+
"""Execute a command or raise an error.
441442
442443
:Parameters:
443444
- `dbname`: name of the database on which to run the command
@@ -455,6 +456,7 @@ def command(self, dbname, spec, slave_ok=False,
455456
- `collation`: The collation for this command.
456457
- `session`: optional ClientSession instance.
457458
"""
459+
self.check_session_auth_matches(session)
458460
if self.max_wire_version < 4 and not read_concern.ok_for_legacy:
459461
raise ConfigurationError(
460462
'read concern level of %s is not valid '
@@ -583,6 +585,12 @@ def authenticate(self, credentials):
583585
auth.authenticate(credentials, self)
584586
self.authset.add(credentials)
585587

588+
def check_session_auth_matches(self, session):
589+
"""Raise error if a ClientSession is logged in as a different user."""
590+
if session and session._authset != self.authset:
591+
raise InvalidOperation('start_session was called while'
592+
' authenticated with different credentials')
593+
586594
def close(self):
587595
self.closed = True
588596
# Avoid exceptions on interpreter shutdown.

test/test_session.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def failed(self, event):
4545
if not event.command_name.startswith('sasl'):
4646
super(SessionTestListener, self).failed(event)
4747

48+
def first_command_started(self):
49+
assert len(self.results['started']) >= 1, (
50+
"No command-started events")
51+
52+
return self.results['started'][0]
53+
4854

4955
def session_ids(client):
5056
return [s.session_id for s in client._topology._session_pool]
@@ -122,16 +128,63 @@ def _test_ops(self, client, *ops, **kwargs):
122128
@client_context.require_auth
123129
@ignore_deprecations
124130
def test_session_authenticate_multiple(self):
131+
listener = SessionTestListener()
125132
# Logged in as root.
126-
client = rs_or_single_client()
127-
client.pymongo_test.add_user('second-user', 'pass')
128-
self.addCleanup(client.pymongo_test.remove_user, 'second-user')
129-
130-
client.pymongo_test.authenticate('second-user', 'pass')
133+
client = rs_or_single_client(event_listeners=[listener])
134+
db = client.pymongo_test
135+
db.add_user('second-user', 'pass')
136+
self.addCleanup(db.remove_user, 'second-user')
137+
db.authenticate('second-user', 'pass')
131138

132139
with self.assertRaises(InvalidOperation):
133140
client.start_session()
134141

142+
# No implicit sessions.
143+
listener.results.clear()
144+
db.collection.find_one()
145+
event = listener.first_command_started()
146+
self.assertNotIn(
147+
'lsid', event.command,
148+
"find_one with multi-auth shouldn't have sent lsid with %s" % (
149+
event.command_name))
150+
151+
# Changing auth invalidates the session. Start as root.
152+
client = rs_or_single_client(event_listeners=[listener])
153+
db = client.pymongo_test
154+
db.collection.insert_many([{} for _ in range(10)])
155+
self.addCleanup(db.collection.drop)
156+
with client.start_session() as s:
157+
listener.results.clear()
158+
cursor = db.collection.find(session=s).batch_size(2)
159+
next(cursor)
160+
event = listener.first_command_started()
161+
self.assertEqual(event.command_name, 'find')
162+
self.assertEqual(
163+
s.session_id, event.command.get('lsid'),
164+
"find sent wrong lsid with %s" % (event.command_name,))
165+
166+
client.admin.logout()
167+
db.authenticate('second-user', 'pass')
168+
169+
err = 'start_session was called while authenticated with' \
170+
' different credentials'
171+
172+
with self.assertRaisesRegex(InvalidOperation, err):
173+
# Auth has changed between find and getMore.
174+
list(cursor)
175+
176+
with self.assertRaisesRegex(InvalidOperation, err):
177+
db.collection.bulk_write([InsertOne({})], session=s)
178+
179+
with self.assertRaisesRegex(InvalidOperation, err):
180+
db.collection_names(session=s)
181+
182+
with self.assertRaisesRegex(InvalidOperation, err):
183+
db.collection.find_one(session=s)
184+
185+
with self.assertRaisesRegex(InvalidOperation, err):
186+
list(db.collection.aggregate([], session=s))
187+
135188
def test_pool_lifo(self):
136189
# "Pool is LIFO" test from Driver Sessions Spec.
137190
a = self.client.start_session()
@@ -380,8 +433,7 @@ def test_cursor(self):
380433
for name, f in ops:
381434
listener.results.clear()
382435
f(session=None)
383-
self.assertGreaterEqual(len(listener.results['started']), 1)
384-
event0 = listener.results['started'][0]
436+
event0 = listener.first_command_started()
385437
self.assertTrue(
386438
'lsid' in event0.command,
387439
"%s sent no lsid with %s" % (
@@ -573,8 +625,7 @@ def test_aggregate_error(self):
573625
with self.assertRaises(OperationFailure):
574626
coll.aggregate([{'$badOperation': {'bar': 1}}])
575627

576-
self.assertEqual(len(listener.results['started']), 1)
577-
event = listener.results['started'][0]
628+
event = listener.first_command_started()
578629
self.assertEqual(event.command_name, 'aggregate')
579630
lsid = event.command['lsid']
580631
# Session was returned to pool despite error.

0 commit comments

Comments
 (0)