diff --git a/msal/application.py b/msal/application.py index 17fc8dc9..20a77525 100644 --- a/msal/application.py +++ b/msal/application.py @@ -6,6 +6,7 @@ import logging import sys import warnings +import uuid import requests @@ -49,6 +50,16 @@ def decorate_scope( decorated = scope_set | reserved_scope return list(decorated) +CLIENT_REQUEST_ID = 'client-request-id' +CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry' + +def _get_new_correlation_id(): + return str(uuid.uuid4()) + + +def _build_current_telemetry_request_header(public_api_id, force_refresh=False): + return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") + def extract_certs(public_cert_content): # Parses raw public certificate file contents and returns a list of strings @@ -68,6 +79,15 @@ def extract_certs(public_cert_content): class ClientApplication(object): + ACQUIRE_TOKEN_SILENT_ID = "84" + ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID = "301" + ACQUIRE_TOKEN_ON_BEHALF_OF_ID = "523" + ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622" + ACQUIRE_TOKEN_FOR_CLIENT_ID = "730" + ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832" + GET_ACCOUNTS_ID = "902" + REMOVE_ACCOUNT_ID = "903" + def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, @@ -303,6 +323,11 @@ def acquire_token_by_authorization_code( data=dict( kwargs.pop("data", {}), scope=decorate_scope(scopes, self.client_id)), + headers={ + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID), + }, **kwargs) def get_accounts(self, username=None): @@ -426,6 +451,7 @@ def acquire_token_silent( """ assert isinstance(scopes, list), "Invalid parameter type" self._validate_ssh_cert_input_data(kwargs.get("data", {})) + correlation_id = _get_new_correlation_id() if authority: warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( @@ -433,7 +459,9 @@ def acquire_token_silent( # verify=self.verify, proxies=self.proxies, timeout=self.timeout, # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( - scopes, account, self.authority, force_refresh=force_refresh, **kwargs) + scopes, account, self.authority, force_refresh=force_refresh, + correlation_id=correlation_id, + **kwargs) if result: return result for alias in self._get_authority_aliases(self.authority.instance): @@ -442,7 +470,9 @@ def acquire_token_silent( validate_authority=False, verify=self.verify, proxies=self.proxies, timeout=self.timeout) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( - scopes, account, the_authority, force_refresh=force_refresh, **kwargs) + scopes, account, the_authority, force_refresh=force_refresh, + correlation_id=correlation_id, + **kwargs) if result: return result @@ -480,7 +510,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( } return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, decorate_scope(scopes, self.client_id), account, - **kwargs) + force_refresh=force_refresh, **kwargs) def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self, authority, scopes, account, **kwargs): @@ -526,7 +556,8 @@ def _get_app_metadata(self, environment): def _acquire_token_silent_by_finding_specific_refresh_token( self, authority, scopes, query, - rt_remover=None, break_condition=lambda response: False, **kwargs): + rt_remover=None, break_condition=lambda response: False, + force_refresh=False, correlation_id=None, **kwargs): matches = self.token_cache.find( self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # AAD RTs are scope-independent @@ -539,6 +570,11 @@ def _acquire_token_silent_by_finding_specific_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], on_removing_rt=rt_remover or self.token_cache.remove_rt, scope=scopes, + headers={ + CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_SILENT_ID, force_refresh=force_refresh), + }, **kwargs) if "error" not in response: return response @@ -564,6 +600,8 @@ def _validate_ssh_cert_input_data(self, data): class PublicClientApplication(ClientApplication): # browser app or mobile app + DEVICE_FLOW_CORRELATION_ID = "_correlation_id" + def __init__(self, client_id, client_credential=None, **kwargs): if client_credential is not None: raise ValueError("Public Client should not possess credentials") @@ -581,9 +619,16 @@ def initiate_device_flow(self, scopes=None, **kwargs): - A successful response would contain "user_code" key, among others - an error response would contain some other readable key/value pairs. """ - return self.client.initiate_device_flow( + correlation_id = _get_new_correlation_id() + flow = self.client.initiate_device_flow( scope=decorate_scope(scopes or [], self.client_id), + headers={ + CLIENT_REQUEST_ID: correlation_id, + # CLIENT_CURRENT_TELEMETRY is not currently required + }, **kwargs) + flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id + return flow def acquire_token_by_device_flow(self, flow, **kwargs): """Obtain token by a device flow object, with customizable polling effect. @@ -600,12 +645,18 @@ def acquire_token_by_device_flow(self, flow, **kwargs): - an error response would contain "error" and usually "error_description". """ return self.client.obtain_token_by_device_flow( - flow, - data=dict(kwargs.pop("data", {}), code=flow["device_code"]), - # 2018-10-4 Hack: - # during transition period, - # service seemingly need both device_code and code parameter. - **kwargs) + flow, + data=dict(kwargs.pop("data", {}), code=flow["device_code"]), + # 2018-10-4 Hack: + # during transition period, + # service seemingly need both device_code and code parameter. + headers={ + CLIENT_REQUEST_ID: + flow.get(self.DEVICE_FLOW_CORRELATION_ID) or _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID), + }, + **kwargs) def acquire_token_by_username_password( self, username, password, scopes, **kwargs): @@ -625,13 +676,22 @@ def acquire_token_by_username_password( - an error response would contain "error" and usually "error_description". """ scopes = decorate_scope(scopes, self.client_id) + headers = { + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID), + } if not self.authority.is_adfs: - user_realm_result = self.authority.user_realm_discovery(username) + user_realm_result = self.authority.user_realm_discovery( + username, correlation_id=headers[CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": return self._acquire_token_by_username_password_federated( - user_realm_result, username, password, scopes=scopes, **kwargs) + user_realm_result, username, password, scopes=scopes, + headers=headers, **kwargs) return self.client.obtain_token_by_username_password( - username, password, scope=scopes, **kwargs) + username, password, scope=scopes, + headers=headers, + **kwargs) def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): @@ -687,8 +747,13 @@ def acquire_token_for_client(self, scopes, **kwargs): """ # TBD: force_refresh behavior return self.client.obtain_token_for_client( - scope=scopes, # This grant flow requires no scope decoration - **kwargs) + scope=scopes, # This grant flow requires no scope decoration + headers={ + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_FOR_CLIENT_ID), + }, + **kwargs) def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs): """Acquires token using on-behalf-of (OBO) flow. @@ -723,5 +788,10 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs): # so that the calling app could use id_token_claims to implement # their own cache mapping, which is likely needed in web apps. data=dict(kwargs.pop("data", {}), requested_token_use="on_behalf_of"), + headers={ + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID), + }, **kwargs) diff --git a/msal/authority.py b/msal/authority.py index dae97aab..d8221eca 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -82,7 +82,7 @@ def __init__(self, authority_url, validate_authority=True, _, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID self.is_adfs = self.tenant.lower() == 'adfs' - def user_realm_discovery(self, username, response=None): + def user_realm_discovery(self, username, correlation_id=None, response=None): # It will typically return a dict containing "ver", "account_type", # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. @@ -90,7 +90,8 @@ def user_realm_discovery(self, username, response=None): resp = response or requests.get( "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( netloc=self.instance, username=username), - headers={'Accept':'application/json'}, + headers={'Accept':'application/json', + 'client-request-id': correlation_id}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) if resp.status_code != 404: resp.raise_for_status()