Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,23 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
force_refresh=False, # type: Optional[boolean]
**kwargs):
if not force_refresh:
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query={
query={
"client_id": self.client_id,
"environment": authority.instance,
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
})
"home_account_id": (account or {}).get("home_account_id"),
# Some token types (SSH-certs, POP) are bound to a key
}

key_id = kwargs.get("data", {}).get("key_id", None)
if (key_id):
query["key_id"] = key_id

matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query = query)

now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
Expand Down
18 changes: 18 additions & 0 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ def _obtain_token(self, grant_type, params=None, data=None,
*args, **kwargs):
RT = "refresh_token"
_data = data.copy() # to prevent side effect
self._validateSSHCertRequestParams(_data)

refresh_token = _data.get(RT)
if grant_type == RT and isinstance(refresh_token, dict):
_data[RT] = rt_getter(refresh_token) # Put raw RT in _data
Expand Down Expand Up @@ -461,3 +463,19 @@ def obtain_token_by_assertion(
data.update(scope=scope, assertion=encoder(assertion))
return self._obtain_token(grant_type, data=data, **kwargs)

def _validateSSHCertRequestParams(self, requestParameters):
if (requestParameters.get("token_type")):
if ("ssh-cert".casefold() == requestParameters.get("token_type").casefold()):

if (not requestParameters.get("req_cnf")):
raise ValueError(
"""When requesting an SSH certificate, you must include a string parameter named 'req_cnf' containing
the public key in JWK format (https://tools.ietf.org/html/rfc7517).""")

if (not requestParameters.get("key_id")):
raise ValueError(
"""When requesting an SSH certificate, you must include a string parameter named 'key_id'
which identifies the key in the 'req_cnf' argument""")

else:
raise ValueError("The token_type value of %s is not recognized" % requestParameters.get("token_type"))
5 changes: 5 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __add(self, event, now=None):
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
response = event.get("response", {})
data = event.get("data", {})
access_token = response.get("access_token")
refresh_token = response.get("refresh_token")
id_token = response.get("id_token")
Expand Down Expand Up @@ -165,6 +166,10 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}

if (at.get("token_type").casefold() == "ssh-cert"):
at["key_id"] = data.get("key_id")

self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

if client_info:
Expand Down
78 changes: 67 additions & 11 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,8 @@ def test_username_password(self):

def test_auth_code(self):
self.skipUnlessWithConfig(["client_id", "scope"])
from msal.oauth2cli.authcode import obtain_auth_code
self.app = msal.ClientApplication(
self.config["client_id"],
client_credential=self.config.get("client_secret"),
authority=self.config.get("authority"))
port = self.config.get("listen_port", 44331)
redirect_uri = "http://localhost:%s" % port
auth_request_uri = self.app.get_authorization_request_url(
self.config["scope"], redirect_uri=redirect_uri)
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
self.assertNotEqual(ac, None)

(ac, redirect_uri) = self._acquire_auth_code()

result = self.app.acquire_token_by_authorization_code(
ac, self.config["scope"], redirect_uri=redirect_uri)
Expand All @@ -120,6 +111,71 @@ def test_auth_code(self):
error_description=result.get("error_description")))
self.assertCacheWorksForUser(result, self.config["scope"], username=None)


def test_ssh_cert(self):
self.skipUnlessWithConfig(["client_id", "scope"])

JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}"""

(ac, redirect_uri) = self._acquire_auth_code()

result = self.app.acquire_token_by_authorization_code(
ac, self.config["scope"], redirect_uri=redirect_uri,
data={ "token_type": "ssh-cert", "key_id": "key1", "req_cnf": JWK1 },
params = self._get_ssh_test_slice())

self.assertEqual(result["token_type"], "ssh-cert")
logger.debug("%s.cache = %s",
self.id(), json.dumps(self.app.token_cache._cache, indent=4))

accessTokens = self.app.token_cache._cache.get("AccessToken")
self.assertEqual(len(accessTokens), 1)
singleAccessToken = next(iter(accessTokens.values()))
self.assertEqual(singleAccessToken.get("key_id"), "key1", "The AT should be bound to the key")

# AcquireTokenSilent needs to be passed the same key to work
account = self.app.get_accounts()[0]
result_from_cache = self.app.acquire_token_silent(
self.config["scope"],
account=account,
data={ "token_type": "ssh-cert", "key_id": "key1", "req_cnf": JWK1 })

self.assertIsNotNone(result_from_cache)
self.assertEqual(result['access_token'], result_from_cache['access_token'], "We should get the cached SSH-cert")

# refresh_token grant can fetch an ssh-cert bound to a different key
refreshed_ssh_cert = self.app.acquire_token_silent(
self.config["scope"],
account=account,
data={ "token_type": "ssh-cert", "key_id": "key2", "req_cnf": JWK2 },
params = self._get_ssh_test_slice())
self.assertIsNotNone(refreshed_ssh_cert)
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
self.assertNotEqual(refreshed_ssh_cert['access_token'], result_from_cache['access_token'])

def _get_ssh_test_slice(self):
return {
"dc": "prod-wst-test1",
"slice": "test" ,
"sshcrt": "true"
}

def _acquire_auth_code(self):
from msal.oauth2cli.authcode import obtain_auth_code
self.app = msal.ClientApplication(
self.config["client_id"],
client_credential=self.config.get("client_secret"),
authority=self.config.get("authority"))
port = self.config.get("listen_port", 44331)
redirect_uri = "http://localhost:%s" % port
auth_request_uri = self.app.get_authorization_request_url(
self.config["scope"], redirect_uri=redirect_uri)
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
self.assertNotEqual(ac, None)

return (ac, redirect_uri)

def test_client_secret(self):
self.skipUnlessWithConfig(["client_id", "client_secret"])
self.app = msal.ConfidentialClientApplication(
Expand Down