diff --git a/msal/application.py b/msal/application.py index d6fb131a..f6e032bc 100644 --- a/msal/application.py +++ b/msal/application.py @@ -14,6 +14,7 @@ import requests from .oauth2cli import Client, JwtAssertionCreator +from .oauth2cli.oidc import decode_part from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request @@ -25,7 +26,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "1.13.0" +__version__ = "1.14.0" logger = logging.getLogger(__name__) @@ -111,6 +112,34 @@ def _preferred_browser(): return None +class _ClientWithCcsRoutingInfo(Client): + + def initiate_auth_code_flow(self, **kwargs): + return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow( + client_info=1, # To be used as CSS Routing info + **kwargs) + + def obtain_token_by_auth_code_flow( + self, auth_code_flow, auth_response, **kwargs): + # Note: the obtain_token_by_browser() is also covered by this + assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict) + headers = kwargs.pop("headers", {}) + client_info = json.loads( + decode_part(auth_response["client_info"]) + ) if auth_response.get("client_info") else {} + if "uid" in client_info and "utid" in client_info: + # Note: The value of X-AnchorMailbox is also case-insensitive + headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow( + auth_code_flow, auth_response, headers=headers, **kwargs) + + def obtain_token_by_username_password(self, username, password, **kwargs): + headers = kwargs.pop("headers", {}) + headers["X-AnchorMailbox"] = "upn:{}".format(username) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password( + username, password, headers=headers, **kwargs) + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -139,6 +168,7 @@ def __init__( # This way, it holds the same positional param place for PCA, # when we would eventually want to add this feature to PCA in future. exclude_scopes=None, + http_cache=None, ): """Create an instance of application. @@ -174,7 +204,7 @@ def __init__( you may try use only the leaf cert (in PEM/str format) instead. *Added in version 1.13.0*: - It can also be a completly pre-signed assertion that you've assembled yourself. + It can also be a completely pre-signed assertion that you've assembled yourself. Simply pass a container containing only the key "client_assertion", like this:: { @@ -305,6 +335,46 @@ def __init__( If that is unnecessary or undesirable for your app, now you can use this parameter to supply an exclusion list of scopes, such as ``exclude_scopes = ["offline_access"]``. + + :param dict http_cache: + MSAL has long been caching tokens in the ``token_cache``. + Recently, MSAL also introduced a concept of ``http_cache``, + by automatically caching some finite amount of non-token http responses, + so that *long-lived* + ``PublicClientApplication`` and ``ConfidentialClientApplication`` + would be more performant and responsive in some situations. + + This ``http_cache`` parameter accepts any dict-like object. + If not provided, MSAL will use an in-memory dict. + + If your app is a command-line app (CLI), + you would want to persist your http_cache across different CLI runs. + The Python standard library's ``shelve`` module comes in handy. Recipe:: + + # Just add the following 3 lines at the beginning of your CLI script + import sys, atexit, shelve + persisted_http_cache = shelve.open(sys.argv[0] + ".http_cache") + atexit.register(persisted_http_cache.close) + + # And then you can implement your app as you normally would + app = msal.PublicClientApplication( + "your_client_id", + ..., + http_cache=persisted_http_cache, # Utilize persisted_http_cache + ..., + #token_cache=..., # You may combine the old token_cache trick + # Please refer to token_cache recipe at + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache + ) + app.acquire_token_interactive(["your", "scope"], ...) + + Content inside ``http_cache`` are cheap to obtain. + There is no need to share them among different apps. + + Content inside ``http_cache`` will contain no tokens nor + Personally Identifiable Information (PII). Encryption is unnecessary. + + New in version 1.15.0. """ self.client_id = client_id self.client_credential = client_credential @@ -339,7 +409,7 @@ def __init__( self.http_client.mount("https://", a) self.http_client = ThrottledHttpClient( self.http_client, - {} # Hard code an in-memory cache, for now + {} if http_cache is None else http_cache, # Default to an in-memory dict ) self.app_name = app_name @@ -481,7 +551,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False authority.device_authorization_endpoint or urljoin(authority.token_endpoint, "devicecode"), } - central_client = Client( + central_client = _ClientWithCcsRoutingInfo( central_configuration, self.client_id, http_client=self.http_client, @@ -506,7 +576,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False regional_authority.device_authorization_endpoint or urljoin(regional_authority.token_endpoint, "devicecode"), } - regional_client = Client( + regional_client = _ClientWithCcsRoutingInfo( regional_configuration, self.client_id, http_client=self.http_client, @@ -529,6 +599,7 @@ def initiate_auth_code_flow( login_hint=None, # type: Optional[str] domain_hint=None, # type: Optional[str] claims_challenge=None, + max_age=None, ): """Initiate an auth code flow. @@ -559,6 +630,17 @@ def initiate_auth_code_flow( `here `_ and `here `_. + :param int max_age: + OPTIONAL. Maximum Authentication Age. + Specifies the allowable elapsed time in seconds + since the last time the End-User was actively authenticated. + If the elapsed time is greater than this value, + Microsoft identity platform will actively re-authenticate the End-User. + + MSAL Python will also automatically validate the auth_time in ID token. + + New in version 1.15. + :return: The auth code flow. It is a dict in this form:: @@ -577,7 +659,7 @@ def initiate_auth_code_flow( 3. and then relay this dict and subsequent auth response to :func:`~acquire_token_by_auth_code_flow()`. """ - client = Client( + client = _ClientWithCcsRoutingInfo( {"authorization_endpoint": self.authority.authorization_endpoint}, self.client_id, http_client=self.http_client) @@ -588,6 +670,7 @@ def initiate_auth_code_flow( domain_hint=domain_hint, claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge), + max_age=max_age, ) flow["claims_challenge"] = claims_challenge return flow @@ -654,7 +737,7 @@ def get_authorization_request_url( self.http_client ) if authority else self.authority - client = Client( + client = _ClientWithCcsRoutingInfo( {"authorization_endpoint": the_authority.authorization_endpoint}, self.client_id, http_client=self.http_client) @@ -1178,6 +1261,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token( key=lambda e: int(e.get("last_modification_time", "0")), reverse=True): logger.debug("Cache attempts an RT") + headers = telemetry_context.generate_headers() + if "home_account_id" in query: # Then use it as CCS Routing info + headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value + query["home_account_id"].replace(".", "@")) response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], on_removing_rt=lambda rt_item: None, # Disable RT removal, @@ -1189,7 +1276,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token( skip_account_creation=True, # To honor a concurrent remove_account() )), scope=scopes, - headers=telemetry_context.generate_headers(), + headers=headers, data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( @@ -1370,6 +1457,7 @@ def acquire_token_interactive( timeout=None, port=None, extra_scopes_to_consent=None, + max_age=None, **kwargs): """Acquire token interactively i.e. via a local browser. @@ -1415,6 +1503,17 @@ def acquire_token_interactive( in the same interaction, but for which you won't get back a token for in this particular operation. + :param int max_age: + OPTIONAL. Maximum Authentication Age. + Specifies the allowable elapsed time in seconds + since the last time the End-User was actively authenticated. + If the elapsed time is greater than this value, + Microsoft identity platform will actively re-authenticate the End-User. + + MSAL Python will also automatically validate the auth_time in ID token. + + New in version 1.15. + :return: - A dict containing no "error" key, and typically contains an "access_token" key. @@ -1433,6 +1532,7 @@ def acquire_token_interactive( port=port or 0), prompt=prompt, login_hint=login_hint, + max_age=max_age, timeout=timeout, auth_params={ "claims": claims, diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 24e3f642..85bbd889 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -8,6 +8,8 @@ import logging import socket from string import Template +import threading +import time try: # Python 3 from http.server import HTTPServer, BaseHTTPRequestHandler @@ -143,17 +145,14 @@ def __init__(self, port=None): # TODO: But, it would treat "localhost" or "" as IPv4. # If pressed, we might just expose a family parameter to caller. self._server = Server((address, port or 0), _AuthCodeHandler) + self._closing = False def get_port(self): """The port this server actually listening to""" # https://docs.python.org/2.7/library/socketserver.html#SocketServer.BaseServer.server_address return self._server.server_address[1] - def get_auth_response(self, auth_uri=None, timeout=None, state=None, - welcome_template=None, success_template=None, error_template=None, - auth_uri_callback=None, - browser_name=None, - ): + def get_auth_response(self, timeout=None, **kwargs): """Wait and return the auth response. Raise RuntimeError when timeout. :param str auth_uri: @@ -192,6 +191,37 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, and https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse Returns None when the state was mismatched, or when timeout occurred. """ + # Historically, the _get_auth_response() uses HTTPServer.handle_request(), + # because its handle-and-retry logic is conceptually as easy as a while loop. + # Also, handle_request() honors server.timeout setting, and CTRL+C simply works. + # All those are true when running on Linux. + # + # However, the behaviors on Windows turns out to be different. + # A socket server waiting for request would freeze the current thread. + # Neither timeout nor CTRL+C would work. End user would have to do CTRL+BREAK. + # https://stackoverflow.com/questions/1364173/stopping-python-using-ctrlc + # + # The solution would need to somehow put the http server into its own thread. + # This could be done by the pattern of ``http.server.test()`` which internally + # use ``ThreadingHTTPServer.serve_forever()`` (only available in Python 3.7). + # Or create our own thread to wrap the HTTPServer.handle_request() inside. + result = {} # A mutable object to be filled with thread's return value + t = threading.Thread( + target=self._get_auth_response, args=(result,), kwargs=kwargs) + t.daemon = True # So that it won't prevent the main thread from exiting + t.start() + begin = time.time() + while (time.time() - begin < timeout) if timeout else True: + time.sleep(1) # Short detection interval to make happy path responsive + if not t.is_alive(): # Then the thread has finished its job and exited + break + return result or None + + def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None, + welcome_template=None, success_template=None, error_template=None, + auth_uri_callback=None, + browser_name=None, + ): welcome_uri = "http://localhost:{p}".format(p=self.get_port()) abort_uri = "{loc}?error=abort".format(loc=welcome_uri) logger.debug("Abort by visit %s", abort_uri) @@ -229,7 +259,8 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, self._server.timeout = timeout # Otherwise its handle_timeout() won't work self._server.auth_response = {} # Shared with _AuthCodeHandler - while True: + while not self._closing: # Otherwise, the handle_request() attempt + # would yield noisy ValueError trace # Derived from # https://docs.python.org/2/library/basehttpserver.html#more-examples self._server.handle_request() @@ -238,10 +269,11 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, logger.debug("State mismatch. Ignoring this noise.") else: break - return self._server.auth_response + result.update(self._server.auth_response) # Return via writable result param def close(self): """Either call this eventually; or use the entire class as context manager""" + self._closing = True self._server.server_close() def __enter__(self): diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 305061cf..8d337bb9 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -199,7 +199,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data["client_assertion"] = encoder( self.client_assertion() # Do lazy on-the-fly computation if callable(self.client_assertion) else self.client_assertion - ) # The type is bytes, which is preferrable. See also: + ) # The type is bytes, which is preferable. See also: # https://github.com/psf/requests/issues/4503#issuecomment-455001070 _data.update(self.default_body) # It may contain authen parameters diff --git a/msal/oauth2cli/oidc.py b/msal/oauth2cli/oidc.py index 114693b1..4f1ca2bd 100644 --- a/msal/oauth2cli/oidc.py +++ b/msal/oauth2cli/oidc.py @@ -42,7 +42,7 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None) """ decoded = json.loads(decode_part(id_token.split('.')[1])) err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation - _now = now or time.time() + _now = int(now or time.time()) skew = 120 # 2 minutes if _now + skew < decoded.get("nbf", _now - 1): # nbf is optional per JWT specs # This is not an ID token validation, but a JWT validation @@ -67,14 +67,14 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None) # the Client and the Token Endpoint (which it is during _obtain_token()), # the TLS server validation MAY be used to validate the issuer # in place of checking the token signature. - if _now > decoded["exp"]: + if _now - skew > decoded["exp"]: err = "9. The current time MUST be before the time represented by the exp Claim." if nonce and nonce != decoded.get("nonce"): err = ("11. Nonce must be the same value " "as the one that was sent in the Authentication Request.") if err: - raise RuntimeError("%s The id_token was: %s" % ( - err, json.dumps(decoded, indent=2))) + raise RuntimeError("%s Current epoch = %s. The id_token was: %s" % ( + err, _now, json.dumps(decoded, indent=2))) return decoded @@ -187,6 +187,8 @@ def initiate_auth_code_flow( flow = super(Client, self).initiate_auth_code_flow( scope=_scope, nonce=_nonce_hash(nonce), **kwargs) flow["nonce"] = nonce + if kwargs.get("max_age") is not None: + flow["max_age"] = kwargs["max_age"] return flow def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs): @@ -208,6 +210,26 @@ def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs raise RuntimeError( 'The nonce in id token ("%s") should match our nonce ("%s")' % (nonce_in_id_token, expected_hash)) + + if auth_code_flow.get("max_age") is not None: + auth_time = result.get("id_token_claims", {}).get("auth_time") + if not auth_time: + raise RuntimeError( + "13. max_age was requested, ID token should contain auth_time") + now = int(time.time()) + skew = 120 # 2 minutes. Hardcoded, for now + if now - skew > auth_time + auth_code_flow["max_age"]: + raise RuntimeError( + "13. auth_time ({auth_time}) was requested, " + "by using max_age ({max_age}) parameter, " + "and now ({now}) too much time has elasped " + "since last end-user authentication. " + "The ID token was: {id_token}".format( + auth_time=auth_time, + max_age=auth_code_flow["max_age"], + now=now, + id_token=json.dumps(result["id_token_claims"], indent=2), + )) return result def obtain_token_by_browser( diff --git a/tests/test_application.py b/tests/test_application.py index ea98b16f..5a92c8d4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -5,7 +5,7 @@ import msal from msal.application import _merge_claims_challenge_and_capabilities from tests import unittest -from tests.test_token_cache import TokenCacheTestCase +from tests.test_token_cache import build_id_token, build_response from tests.http_client import MinimalHttpClient, MinimalResponse from msal.telemetry import CLIENT_CURRENT_TELEMETRY, CLIENT_LAST_TELEMETRY @@ -66,7 +66,7 @@ def setUp(self): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token="an expired AT to trigger refresh", expires_in=-99, uid=self.uid, utid=self.utid, refresh_token=self.rt), }) # The add(...) helper populates correct home_account_id for future searching @@ -125,9 +125,9 @@ def setUp(self): "client_id": self.preexisting_family_app_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token="Siblings won't share AT. test_remove_account() will.", - id_token=TokenCacheTestCase.build_id_token(aud=self.preexisting_family_app_id), + id_token=build_id_token(aud=self.preexisting_family_app_id), uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), }) # The add(...) helper populates correct home_account_id for future searching @@ -153,8 +153,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): "client_id": app.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, refresh_token=rt), + "response": build_response(uid=self.uid, utid=self.utid, refresh_token=rt), }) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): @@ -168,7 +167,7 @@ def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") return MinimalResponse( - status_code=200, text=json.dumps(TokenCacheTestCase.build_response( + status_code=200, text=json.dumps(build_response( uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) @@ -246,7 +245,7 @@ def setUp(self): "scope": self.scopes, "token_endpoint": "https://{}/common/oauth2/v2.0/token".format( self.environment_in_cache), - "response": TokenCacheTestCase.build_response( + "response": build_response( uid=uid, utid=utid, access_token=self.access_token, refresh_token="some refresh token"), }) # The add(...) helper populates correct home_account_id for future searching @@ -342,7 +341,7 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token=access_token, expires_in=expires_in, refresh_in=refresh_in, uid=self.uid, utid=self.utid, refresh_token=self.rt), @@ -424,7 +423,7 @@ def populate_cache(self, cache, access_token="at"): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token=access_token, uid=self.uid, utid=self.utid, refresh_token=self.rt), }) @@ -571,9 +570,9 @@ def test_get_accounts(self): "scope": scopes, "token_endpoint": "https://{}/{}/oauth2/v2.0/token".format(environment, tenant), - "response": TokenCacheTestCase.build_response( + "response": build_response( uid=uid, utid=utid, access_token="at", refresh_token="rt", - id_token=TokenCacheTestCase.build_id_token( + id_token=build_id_token( aud=client_id, sub="oid_in_" + tenant, preferred_username=username, diff --git a/tests/test_ccs.py b/tests/test_ccs.py new file mode 100644 index 00000000..8b801773 --- /dev/null +++ b/tests/test_ccs.py @@ -0,0 +1,73 @@ +import unittest +try: + from unittest.mock import patch, ANY +except: + from mock import patch, ANY + +from tests.http_client import MinimalResponse +from tests.test_token_cache import build_response + +import msal + + +class TestCcsRoutingInfoTestCase(unittest.TestCase): + + def test_acquire_token_by_auth_code_flow(self): + app = msal.ClientApplication("client_id") + state = "foo" + flow = app.initiate_auth_code_flow( + ["some", "scope"], login_hint="johndoe@contoso.com", state=state) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_auth_code_flow(flow, { + "state": state, + "code": "bar", + "client_info": # MSAL asks for client_info, so it would be available + "eyJ1aWQiOiJhYTkwNTk0OS1hMmI4LTRlMGEtOGFlYS1iMzJlNTNjY2RiNDEiLCJ1dGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3In0", + }) + self.assertEqual( + "Oid:aa905949-a2b8-4e0a-8aea-b32e53ccdb41@72f988bf-86f1-41af-91ab-2d7cd011db47", + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + + # I've manually tested acquire_token_interactive. No need to automate it, + # because it and acquire_token_by_auth_code_flow() share same code path. + + def test_acquire_token_silent(self): + uid = "foo" + utid = "bar" + client_id = "my_client_id" + scopes = ["some", "scope"] + authority_url = "https://login.microsoftonline.com/common" + token_cache = msal.TokenCache() + token_cache.add({ # Pre-populate the cache + "client_id": client_id, + "scope": scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(authority_url), + "response": build_response( + access_token="an expired AT to trigger refresh", expires_in=-99, + uid=uid, utid=utid, refresh_token="this is a RT"), + }) # The add(...) helper populates correct home_account_id for future searching + app = msal.ClientApplication( + client_id, authority=authority_url, token_cache=token_cache) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + account = {"home_account_id": "{}.{}".format(uid, utid)} + app.acquire_token_silent(["scope"], account) + self.assertEqual( + "Oid:{}@{}".format( # Server accepts case-insensitive value + uid, utid), # It would look like "Oid:foo@bar" + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from home_account_id") + + def test_acquire_token_by_username_password(self): + app = msal.ClientApplication("client_id") + username = "johndoe@contoso.com" + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_username_password(username, "password", ["scope"]) + self.assertEqual( + "upn:" + username, + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 20afaa0a..2defecd6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -516,8 +516,8 @@ def _test_acquire_token_by_auth_code_flow( client_id, authority=authority, http_client=MinimalHttpClient()) with AuthCodeReceiver(port=port) as receiver: flow = self.app.initiate_auth_code_flow( + scope, redirect_uri="http://localhost:%d" % receiver.get_port(), - scopes=scope, ) auth_response = receiver.get_auth_response( auth_uri=flow["auth_uri"], state=flow["state"], timeout=60, diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 3cce0c82..2fe486c2 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -11,52 +11,56 @@ logging.basicConfig(level=logging.DEBUG) -class TokenCacheTestCase(unittest.TestCase): +# NOTE: These helpers were once implemented as static methods in TokenCacheTestCase. +# That would cause other test files' "from ... import TokenCacheTestCase" +# to re-run all test cases in this file. +# Now we avoid that, by defining these helpers in module level. +def build_id_token( + iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, + **claims): # AAD issues "preferred_username", ADFS issues "upn" + return "header.%s.signature" % base64.b64encode(json.dumps(dict({ + "iss": iss, + "sub": sub, + "aud": aud, + "exp": exp or (time.time() + 100), + "iat": iat or time.time(), + }, **claims)).encode()).decode('utf-8') + - @staticmethod - def build_id_token( - iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, - **claims): # AAD issues "preferred_username", ADFS issues "upn" - return "header.%s.signature" % base64.b64encode(json.dumps(dict({ - "iss": iss, - "sub": sub, - "aud": aud, - "exp": exp or (time.time() + 100), - "iat": iat or time.time(), - }, **claims)).encode()).decode('utf-8') +def build_response( # simulate a response from AAD + uid=None, utid=None, # If present, they will form client_info + access_token=None, expires_in=3600, token_type="some type", + **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... + ): + response = {} + if uid and utid: # Mimic the AAD behavior for "client_info=1" request + response["client_info"] = base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode('utf-8') + if access_token: + response.update({ + "access_token": access_token, + "expires_in": expires_in, + "token_type": token_type, + }) + response.update(kwargs) # Pass-through key-value pairs as top-level fields + return response - @staticmethod - def build_response( # simulate a response from AAD - uid=None, utid=None, # If present, they will form client_info - access_token=None, expires_in=3600, token_type="some type", - **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... - ): - response = {} - if uid and utid: # Mimic the AAD behavior for "client_info=1" request - response["client_info"] = base64.b64encode(json.dumps({ - "uid": uid, "utid": utid, - }).encode()).decode('utf-8') - if access_token: - response.update({ - "access_token": access_token, - "expires_in": expires_in, - "token_type": token_type, - }) - response.update(kwargs) # Pass-through key-value pairs as top-level fields - return response + +class TokenCacheTestCase(unittest.TestCase): def setUp(self): self.cache = TokenCache() def testAddByAad(self): client_id = "my_client_id" - id_token = self.build_id_token( + id_token = build_id_token( oid="object1234", preferred_username="John Doe", aud=client_id) self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), @@ -125,12 +129,12 @@ def testAddByAad(self): def testAddByAdfs(self): client_id = "my_client_id" - id_token = self.build_id_token(aud=client_id, upn="JaneDoe@example.com") + id_token = build_id_token(aud=client_id, upn="JaneDoe@example.com") self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token", - "response": self.build_response( + "response": build_response( uid=None, utid=None, # ADFS will provide no client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), @@ -204,7 +208,7 @@ def test_key_id_is_also_recorded(self): "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", refresh_token="a refresh token"), @@ -219,7 +223,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, refresh_in=1800, access_token="an access token", ), #refresh_token="a refresh token"),