Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add assertion param to backed.authenticate and is_authorized functions
  • Loading branch information
lucyeun-alation committed Apr 14, 2021
commit fedbd28bae58e9ca7c64571355647d23b51c5a22
6 changes: 3 additions & 3 deletions djangosaml2/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _get_attribute_value(self, django_field: str, attributes: dict, attribute_ma
'value is missing. Probably the user '
'session is expired.')

def authenticate(self, request, session_info=None, attribute_mapping=None, create_unknown_user=True, **kwargs):
def authenticate(self, request, session_info=None, attribute_mapping=None, create_unknown_user=True, assertion=None, **kwargs):
if session_info is None or attribute_mapping is None:
logger.info('Session info or attribute mapping are None')
return None
Expand All @@ -121,7 +121,7 @@ def authenticate(self, request, session_info=None, attribute_mapping=None, creat

logger.debug(f'attributes: {attributes}')

if not self.is_authorized(attributes, attribute_mapping, idp_entityid):
if not self.is_authorized(attributes, attribute_mapping, idp_entityid, assertion):
logger.error('Request not authorized')
return None

Expand Down Expand Up @@ -194,7 +194,7 @@ def clean_attributes(self, attributes: dict, idp_entityid: str, **kwargs) -> dic
""" Hook to clean or filter attributes from the SAML response. No-op by default. """
return attributes

def is_authorized(self, attributes: dict, attribute_mapping: dict, idp_entityid: str, **kwargs) -> bool:
def is_authorized(self, attributes: dict, attribute_mapping: dict, idp_entityid: str, assertion: object, **kwargs) -> bool:
""" Hook to allow custom authorization policies based on SAML attributes. True by default. """
return True

Expand Down
3 changes: 2 additions & 1 deletion djangosaml2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def post(self, request, attribute_mapping=None, create_unknown_user=None):
user = auth.authenticate(request=request,
session_info=session_info,
attribute_mapping=attribute_mapping,
create_unknown_user=create_unknown_user)
create_unknown_user=create_unknown_user,
assertion=response.assertion)
if user is None:
logger.warning(
"Could not authenticate user received in SAML Assertion. Session info: %s", session_info)
Expand Down
19 changes: 14 additions & 5 deletions tests/testprofiles/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.test import TestCase, override_settings
from djangosaml2.backends import Saml2Backend, set_attribute
from saml2.saml import Assertion

from testprofiles.models import TestUser

Expand Down Expand Up @@ -104,7 +105,7 @@ def test_extract_user_identifier_params_use_nameid_missing(self):
self.assertEqual(lookup_value, None)

def test_is_authorized(self):
self.assertTrue(self.backend.is_authorized({}, {}, ''))
self.assertTrue(self.backend.is_authorized({}, {}, '', None))

def test_clean_attributes(self):
attributes = {'random': 'dummy', 'value': 123}
Expand Down Expand Up @@ -333,9 +334,9 @@ def test_deprecations(self):
class CustomizedBackend(Saml2Backend):
""" Override the available methods with some customized implementation to test customization
"""
def is_authorized(self, attributes, attribute_mapping, idp_entityid: str, **kwargs):
def is_authorized(self, attributes, attribute_mapping, idp_entityid: str, assertion, **kwargs):
''' Allow only staff users from the IDP '''
return attributes.get('is_staff', (None, ))[0] == True
return attributes.get('is_staff', (None, ))[0] == True and getattr(assertion, 'id', None) != None

def clean_attributes(self, attributes: dict, idp_entityid: str, **kwargs) -> dict:
''' Keep only age attribute '''
Expand Down Expand Up @@ -368,9 +369,12 @@ def test_is_authorized(self):
'cn': ('John', ),
'sn': ('Doe', ),
}
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, ''))
assertion = Assertion()
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, '', assertion))
attributes['is_staff'] = (True, )
self.assertTrue(self.backend.is_authorized(attributes, attribute_mapping, ''))
self.assertFalse(self.backend.is_authorized(attributes, attribute_mapping, '', assertion))
assertion.id = 'abcdefg12345'
self.assertTrue(self.backend.is_authorized(attributes, attribute_mapping, '', assertion))

def test_clean_attributes(self):
attributes = {'random': 'dummy', 'value': 123, 'age': '28'}
Expand All @@ -396,6 +400,7 @@ def test_authenticate(self):
'age': ('28', ),
'is_staff': (True, ),
}
assertion = Assertion(id='abcdefg12345')

self.assertEqual(self.user.age, '')
self.assertEqual(self.user.is_staff, False)
Expand All @@ -409,6 +414,7 @@ def test_authenticate(self):
None,
session_info={'random': 'content'},
attribute_mapping=attribute_mapping,
assertion=assertion,
)
self.assertIsNone(user)

Expand All @@ -417,6 +423,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion=assertion,
)
self.assertIsNone(user)

Expand All @@ -425,6 +432,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion=assertion,
)
self.assertIsNone(user)

Expand All @@ -433,6 +441,7 @@ def test_authenticate(self):
None,
session_info={'ava': attributes, 'issuer': 'dummy_entity_id'},
attribute_mapping=attribute_mapping,
assertion=assertion,
)

self.assertEqual(user, self.user)
Expand Down