Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 86 additions & 16 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import sys
import warnings
import uuid

import requests

Expand Down Expand Up @@ -49,6 +50,16 @@ def decorate_scope(
decorated = scope_set | reserved_scope
return list(decorated)

CLIENT_REQUEST_ID = 'client-request-id'
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
return str(uuid.uuid4())


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
Expand All @@ -68,6 +79,15 @@ def extract_certs(public_cert_content):

class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID = "301"
ACQUIRE_TOKEN_ON_BEHALF_OF_ID = "523"
ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622"
ACQUIRE_TOKEN_FOR_CLIENT_ID = "730"
ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832"
GET_ACCOUNTS_ID = "902"
REMOVE_ACCOUNT_ID = "903"

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -303,6 +323,11 @@ def acquire_token_by_authorization_code(
data=dict(
kwargs.pop("data", {}),
scope=decorate_scope(scopes, self.client_id)),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
**kwargs)

def get_accounts(self, username=None):
Expand Down Expand Up @@ -426,14 +451,17 @@ def acquire_token_silent(
"""
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
correlation_id = _get_new_correlation_id()
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
# authority,
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
# ) if authority else self.authority
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, self.authority, force_refresh=force_refresh, **kwargs)
scopes, account, self.authority, force_refresh=force_refresh,
correlation_id=correlation_id,
**kwargs)
if result:
return result
for alias in self._get_authority_aliases(self.authority.instance):
Expand All @@ -442,7 +470,9 @@ def acquire_token_silent(
validate_authority=False,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, the_authority, force_refresh=force_refresh, **kwargs)
scopes, account, the_authority, force_refresh=force_refresh,
correlation_id=correlation_id,
**kwargs)
if result:
return result

Expand Down Expand Up @@ -480,7 +510,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
}
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
**kwargs)
force_refresh=force_refresh, **kwargs)

def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
Expand Down Expand Up @@ -526,7 +556,8 @@ def _get_app_metadata(self, environment):

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
rt_remover=None, break_condition=lambda response: False,
force_refresh=False, correlation_id=None, **kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
Expand All @@ -539,6 +570,11 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
on_removing_rt=rt_remover or self.token_cache.remove_rt,
scope=scopes,
headers={
CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_SILENT_ID, force_refresh=force_refresh),
},
**kwargs)
if "error" not in response:
return response
Expand All @@ -564,6 +600,8 @@ def _validate_ssh_cert_input_data(self, data):

class PublicClientApplication(ClientApplication): # browser app or mobile app

DEVICE_FLOW_CORRELATION_ID = "_correlation_id"

def __init__(self, client_id, client_credential=None, **kwargs):
if client_credential is not None:
raise ValueError("Public Client should not possess credentials")
Expand All @@ -581,9 +619,16 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- A successful response would contain "user_code" key, among others
- an error response would contain some other readable key/value pairs.
"""
return self.client.initiate_device_flow(
correlation_id = _get_new_correlation_id()
flow = self.client.initiate_device_flow(
scope=decorate_scope(scopes or [], self.client_id),
headers={
CLIENT_REQUEST_ID: correlation_id,
# CLIENT_CURRENT_TELEMETRY is not currently required
},
**kwargs)
flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id
return flow

def acquire_token_by_device_flow(self, flow, **kwargs):
"""Obtain token by a device flow object, with customizable polling effect.
Expand All @@ -600,12 +645,18 @@ def acquire_token_by_device_flow(self, flow, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
return self.client.obtain_token_by_device_flow(
flow,
data=dict(kwargs.pop("data", {}), code=flow["device_code"]),
# 2018-10-4 Hack:
# during transition period,
# service seemingly need both device_code and code parameter.
**kwargs)
flow,
data=dict(kwargs.pop("data", {}), code=flow["device_code"]),
# 2018-10-4 Hack:
# during transition period,
# service seemingly need both device_code and code parameter.
headers={
CLIENT_REQUEST_ID:
flow.get(self.DEVICE_FLOW_CORRELATION_ID) or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
},
**kwargs)

def acquire_token_by_username_password(
self, username, password, scopes, **kwargs):
Expand All @@ -625,13 +676,22 @@ def acquire_token_by_username_password(
- an error response would contain "error" and usually "error_description".
"""
scopes = decorate_scope(scopes, self.client_id)
headers = {
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID),
}
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(username)
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes, **kwargs)
user_realm_result, username, password, scopes=scopes,
headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
username, password, scope=scopes, **kwargs)
username, password, scope=scopes,
headers=headers,
**kwargs)

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -687,8 +747,13 @@ def acquire_token_for_client(self, scopes, **kwargs):
"""
# TBD: force_refresh behavior
return self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
**kwargs)
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID),
},
**kwargs)

def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down Expand Up @@ -723,5 +788,10 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
# so that the calling app could use id_token_claims to implement
# their own cache mapping, which is likely needed in web apps.
data=dict(kwargs.pop("data", {}), requested_token_use="on_behalf_of"),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
**kwargs)

5 changes: 3 additions & 2 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,16 @@ def __init__(self, authority_url, validate_authority=True,
_, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID
self.is_adfs = self.tenant.lower() == 'adfs'

def user_realm_discovery(self, username, response=None):
def user_realm_discovery(self, username, correlation_id=None, response=None):
# It will typically return a dict containing "ver", "account_type",
# "federation_protocol", "cloud_audience_urn",
# "federation_metadata_url", "federation_active_auth_url", etc.
if self.instance not in self.__class__._domains_without_user_realm_discovery:
resp = response or requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json'},
headers={'Accept':'application/json',
'client-request-id': correlation_id},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
if resp.status_code != 404:
resp.raise_for_status()
Expand Down