Skip to content
15 changes: 11 additions & 4 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ def __init__(self, authority_url, validate_authority=True,
self.proxies = proxies
self.timeout = timeout
canonicalized, self.instance, tenant = canonicalize(authority_url)
tenant_discovery_endpoint = ( # Hard code a V2 pattern as default value
'https://{}/{}/v2.0/.well-known/openid-configuration'
.format(self.instance, tenant))
if validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS:
tenant_discovery_endpoint = (
'https://{}/{}{}/.well-known/openid-configuration'.format(
self.instance,
tenant,
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
))
if (tenant != "adfs" and validate_authority
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
tenant_discovery_endpoint = instance_discovery(
canonicalized + "/oauth2/v2.0/authorize",
verify=verify, proxies=proxies, timeout=timeout)
if tenant.lower() == "adfs":
tenant_discovery_endpoint = ("https://{}/adfs/.well-known/openid-configuration"
.format(self.instance))
openid_config = tenant_discovery(
tenant_discovery_endpoint,
verify=verify, proxies=proxies, timeout=timeout)
Expand Down
27 changes: 17 additions & 10 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,25 @@ def add(self, event, now=None):
event, indent=4, sort_keys=True,
default=str, # A workaround when assertion is in bytes in Python 3
))
environment = realm = None
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
response = event.get("response", {})
access_token = response.get("access_token")
refresh_token = response.get("refresh_token")
id_token = response.get("id_token")
id_token_claims = (
decode_id_token(id_token, client_id=event["client_id"])
if id_token else {})
client_info = {}
home_account_id = None
if "client_info" in response:
home_account_id = None # It would remain None in client_credentials flow
if "client_info" in response: # We asked for it, and AAD will provide it
client_info = json.loads(base64decode(response["client_info"]))
home_account_id = "{uid}.{utid}".format(**client_info)
environment = realm = None
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
elif id_token_claims: # This would be an end user on ADFS-direct scenario
client_info["uid"] = id_token_claims.get("sub")
home_account_id = id_token_claims.get("sub")

target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it

with self._lock:
Expand All @@ -148,15 +155,15 @@ def add(self, event, now=None):
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

if client_info:
decoded_id_token = decode_id_token(
id_token, client_id=event["client_id"]) if id_token else {}
account = {
"home_account_id": home_account_id,
"environment": environment,
"realm": realm,
"local_account_id": decoded_id_token.get(
"oid", decoded_id_token.get("sub")),
"username": decoded_id_token.get("preferred_username"),
"local_account_id": id_token_claims.get(
"oid", id_token_claims.get("sub")),
"username": id_token_claims.get("preferred_username") # AAD
or id_token_claims.get("upn") # ADFS 2019
or "", # The schema does not like null
"authority_type":
self.AuthorityType.ADFS if realm == "adfs"
else self.AuthorityType.MSSTS,
Expand Down
87 changes: 79 additions & 8 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,29 @@ class TokenCacheTestCase(unittest.TestCase):
@staticmethod
def build_id_token(
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
preferred_username="me", **claims):
**claims): # AAD issues "preferred_username", ADFS issues "upn"
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
"iss": iss,
"sub": sub,
"aud": aud,
"exp": exp or (time.time() + 100),
"iat": iat or time.time(),
"preferred_username": preferred_username,
}, **claims)).encode()).decode('utf-8')

@staticmethod
def build_response( # simulate a response from AAD
uid="uid", utid="utid", # They will form client_info
uid=None, utid=None, # If present, they will form client_info
access_token=None, expires_in=3600, token_type="some type",
refresh_token=None,
foci=None,
id_token=None, # or something generated by build_id_token()
error=None,
):
response = {
"client_info": base64.b64encode(json.dumps({
response = {}
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
response["client_info"] = base64.b64encode(json.dumps({
"uid": uid, "utid": utid,
}).encode()).decode('utf-8'),
}
}).encode()).decode('utf-8')
if error:
response["error"] = error
if access_token:
Expand All @@ -59,7 +58,7 @@ def build_response( # simulate a response from AAD
def setUp(self):
self.cache = TokenCache()

def testAdd(self):
def testAddByAad(self):
client_id = "my_client_id"
id_token = self.build_id_token(
oid="object1234", preferred_username="John Doe", aud=client_id)
Expand Down Expand Up @@ -132,6 +131,78 @@ def testAdd(self):
"appmetadata-login.example.com-my_client_id")
)

def testAddByAdfs(self):
client_id = "my_client_id"
id_token = self.build_id_token(aud=client_id, upn="[email protected]")
self.cache.add({
"client_id": client_id,
"scope": ["s2", "s1", "s3"], # Not in particular order
"token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token",
"response": self.build_response(
uid=None, utid=None, # ADFS will provide no client_info
expires_in=3600, access_token="an access token",
id_token=id_token, refresh_token="a refresh token"),
}, now=1000)
self.assertEqual(
{
'cached_at': "1000",
'client_id': 'my_client_id',
'credential_type': 'AccessToken',
'environment': 'fs.msidlab8.com',
'expires_on': "4600",
'extended_expires_on': "4600",
'home_account_id': "subject",
'realm': 'adfs',
'secret': 'an access token',
'target': 's2 s1 s3',
},
self.cache._cache["AccessToken"].get(
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
)
self.assertEqual(
{
'client_id': 'my_client_id',
'credential_type': 'RefreshToken',
'environment': 'fs.msidlab8.com',
'home_account_id': "subject",
'secret': 'a refresh token',
'target': 's2 s1 s3',
},
self.cache._cache["RefreshToken"].get(
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
)
self.assertEqual(
{
'home_account_id': "subject",
'environment': 'fs.msidlab8.com',
'realm': 'adfs',
'local_account_id': "subject",
'username': "[email protected]",
'authority_type': "ADFS",
},
self.cache._cache["Account"].get('subject-fs.msidlab8.com-adfs')
)
self.assertEqual(
{
'credential_type': 'IdToken',
'secret': id_token,
'home_account_id': "subject",
'environment': 'fs.msidlab8.com',
'realm': 'adfs',
'client_id': 'my_client_id',
},
self.cache._cache["IdToken"].get(
'subject-fs.msidlab8.com-idtoken-my_client_id-adfs-')
)
self.assertEqual(
{
"client_id": "my_client_id",
'environment': 'fs.msidlab8.com',
},
self.cache._cache.get("AppMetadata", {}).get(
"appmetadata-fs.msidlab8.com-my_client_id")
)


class SerializableTokenCacheTestCase(TokenCacheTestCase):
# Run all inherited test methods, and have extra check in tearDown()
Expand Down