Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions djangosaml2/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _get_attribute_value(self, django_field: str, attributes: dict, attribute_ma
'value is missing. Probably the user '
'session is expired.')

def authenticate(self, request, session_info=None, attribute_mapping=None, create_unknown_user=True, **kwargs):
def authenticate(self, request, session_info=None, attribute_mapping=None, create_unknown_user=True, assertion_info=None, **kwargs):
if session_info is None or attribute_mapping is None:
logger.info('Session info or attribute mapping are None')
return None
Expand All @@ -121,7 +121,7 @@ def authenticate(self, request, session_info=None, attribute_mapping=None, creat

logger.debug(f'attributes: {attributes}')

if not self.is_authorized(attributes, attribute_mapping, idp_entityid):
if not self.is_authorized(attributes, attribute_mapping, idp_entityid, assertion_info):
logger.error('Request not authorized')
return None

Expand Down Expand Up @@ -194,7 +194,7 @@ def clean_attributes(self, attributes: dict, idp_entityid: str, **kwargs) -> dic
""" Hook to clean or filter attributes from the SAML response. No-op by default. """
return attributes

def is_authorized(self, attributes: dict, attribute_mapping: dict, idp_entityid: str, **kwargs) -> bool:
def is_authorized(self, attributes: dict, attribute_mapping: dict, idp_entityid: str, assertion_info: dict, **kwargs) -> bool:
""" Hook to allow custom authorization policies based on SAML attributes. True by default. """
return True

Expand Down
14 changes: 13 additions & 1 deletion djangosaml2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
StatusNoAuthnContext, StatusRequestDenied,
UnsolicitedResponse)
from saml2.s_utils import UnsupportedBinding
from saml2.saml import SCM_BEARER
from saml2.samlp import AuthnRequest
from saml2.sigver import MissingKey
from saml2.validate import ResponseLifetimeExceed, ToEarly
Expand All @@ -56,6 +57,7 @@
get_idp_sso_supported_bindings, get_location,
validate_referral_url)


logger = logging.getLogger('djangosaml2')


Expand Down Expand Up @@ -420,6 +422,15 @@ def post(self, request, attribute_mapping=None, create_unknown_user=None):
# authenticate the remote user
session_info = response.session_info()

# assertion_info
assertion = response.assertion
assertion_info = {}
for sc in assertion.subject.subject_confirmation:
if sc.method == SCM_BEARER:
assertion_not_on_or_after = sc.subject_confirmation_data.not_on_or_after
assertion_info = {'assertion_id': assertion.id, 'not_on_or_after': assertion_not_on_or_after}
break

Comment on lines +426 to +433
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SAML doc says
The service provider MUST ensure that bearer assertions are not replayed, by maintaining the set of used ID values for the length of time for which the assertion would be considered valid based on the NotOnOrAfter attribute in the <SubjectConfirmationData>.

The bearer <SubjectConfirmation> element described above MUST contain a <SubjectConfirmationData> element that contains a Recipient attribute containing the service provider's assertion consumer service URL and a NotOnOrAfter attribute that limits the window during which the assertion can be delivered.

if callable(attribute_mapping):
attribute_mapping = attribute_mapping()
if callable(create_unknown_user):
Expand All @@ -430,7 +441,8 @@ def post(self, request, attribute_mapping=None, create_unknown_user=None):
user = auth.authenticate(request=request,
session_info=session_info,
attribute_mapping=attribute_mapping,
create_unknown_user=create_unknown_user)
create_unknown_user=create_unknown_user,
assertion_info=assertion_info)
if user is None:
logger.warning(
"Could not authenticate user received in SAML Assertion. Session info: %s", session_info)
Expand Down
25 changes: 25 additions & 0 deletions docs/source/contents/setup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,31 @@ setting::

SAML_CONFIG_LOADER = 'python.path.to.your.callable'

Bearer Assertion Replay Attack Prevention
==================================
In SAML standard doc, section 4.1.4.5 it states

The service provider MUST ensure that bearer assertions are not replayed, by maintaining the set of used ID values for the length of time for which the assertion would be considered valid based on the NotOnOrAfter attribute in the <SubjectConfirmationData>

djangosaml2 provides a hook 'is_authorized' for the SP to store assertion IDs and implement replay prevention with your choice of storage.
::

def is_authorized(self, attributes: dict, attribute_mapping: dict, idp_entityid: str, assertion: object, **kwargs) -> bool:
if not assertion:
return True

# Get your choice of storage
cache_storage = storage.get_cache()
assertion_id = assertion.get('assertion_id')

if cache.get(assertion_id):
logger.warn("Received SAMLResponse assertion has been already used.")
return False

expiration_time = assertion.get('not_on_or_after')
time_delta = isoparse(expiration_time) - datetime.now(timezone.utc)
cache_storage.set(assertion_id, 'True', ex=time_delta)
return True

Users, attributes and account linking
-------------------------------------
Expand Down
24 changes: 19 additions & 5 deletions tests/testprofiles/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_extract_user_identifier_params_use_nameid_missing(self):
self.assertEqual(lookup_value, None)

def test_is_authorized(self):
self.assertTrue(self.backend.is_authorized({}, {}, ''))
self.assertTrue(self.backend.is_authorized({}, {}, '', {}))

def test_clean_attributes(self):
attributes = {'random': 'dummy', 'value': 123}
Expand Down Expand Up @@ -333,9 +333,9 @@ def test_deprecations(self):
class CustomizedBackend(Saml2Backend):
""" Override the available methods with some customized implementation to test customization
"""
def is_authorized(self, attributes, attribute_mapping, idp_entityid: str, **kwargs):
def is_authorized(self, attributes, attribute_mapping, idp_entityid: str, assertion_info, **kwargs):
''' Allow only staff users from the IDP '''
return attributes.get('is_staff', (None, ))[0] == True
return attributes.get('is_staff', (None, ))[0] == True and assertion_info.get('assertion_id', None) != None

def clean_attributes(self, attributes: dict, idp_entityid: str, **kwargs) -> dict:
''' Keep only age attribute '''
Expand Down Expand Up @@ -368,9 +368,15 @@ def test_is_authorized(self):
'cn': ('John', ),
'sn': ('Doe', ),
}
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, ''))
assertion_info = {
'assertion_id': None,
'not_on_or_after': None,
}
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, '', assertion_info))
attributes['is_staff'] = (True, )
self.assertTrue(self.backend.is_authorized(attributes, attribute_mapping, ''))
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, '', assertion_info))
assertion_info['assertion_id'] = 'abcdefg12345'
self.assertTrue(self.backend.is_authorized(attributes, attribute_mapping, '', assertion_info))

def test_clean_attributes(self):
attributes = {'random': 'dummy', 'value': 123, 'age': '28'}
Expand All @@ -396,6 +402,10 @@ def test_authenticate(self):
'age': ('28', ),
'is_staff': (True, ),
}
assertion_info = {
'assertion_id': 'abcdefg12345',
'not_on_or_after': '',
}

self.assertEqual(self.user.age, '')
self.assertEqual(self.user.is_staff, False)
Expand All @@ -409,6 +419,7 @@ def test_authenticate(self):
None,
session_info={'random': 'content'},
attribute_mapping=attribute_mapping,
assertion_info=assertion_info,
)
self.assertIsNone(user)

Expand All @@ -417,6 +428,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion_info=assertion_info,
)
self.assertIsNone(user)

Expand All @@ -425,6 +437,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion_info=assertion_info,
)
self.assertIsNone(user)

Expand All @@ -433,6 +446,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion_info=assertion_info,
)

self.assertEqual(user, self.user)
Expand Down