Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import sys
import warnings
import uuid

import requests

Expand Down Expand Up @@ -50,6 +51,14 @@ def decorate_scope(
return list(decorated)


def _get_new_correlation_id():
return str(uuid.uuid4())


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
# Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())}
Expand All @@ -68,6 +77,16 @@ def extract_certs(public_cert_content):

class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID = "301"
ACQUIRE_TOKEN_ON_BEHALF_OF_ID = "523"
INITIATE_DEVICE_FLOW = "621"
ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622"
ACQUIRE_TOKEN_FOR_CLIENT_ID = "730"
ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832"
GET_ACCOUNTS_ID = "902"
REMOVE_ACCOUNT_ID = "903"

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -283,6 +302,9 @@ def acquire_token_by_authorization_code(
data=dict(
kwargs.pop("data", {}),
scope=decorate_scope(scopes, self.client_id)),
headers={'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)},
**kwargs)

def get_accounts(self, username=None):
Expand Down Expand Up @@ -460,7 +482,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
}
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
**kwargs)
force_refresh=force_refresh, **kwargs)

def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
Expand Down Expand Up @@ -506,7 +528,8 @@ def _get_app_metadata(self, environment):

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
rt_remover=None, break_condition=lambda response: False,
force_refresh=False, **kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
Expand All @@ -519,6 +542,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
on_removing_rt=rt_remover or self.token_cache.remove_rt,
scope=scopes,
headers={'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_SILENT_ID, force_refresh)},
**kwargs)
if "error" not in response:
return response
Expand Down Expand Up @@ -562,6 +588,7 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- an error response would contain some other readable key/value pairs.
"""
return self.client.initiate_device_flow(
headers={'client-request-id': _get_new_correlation_id()},
scope=decorate_scope(scopes or [], self.client_id),
**kwargs)

Expand All @@ -585,6 +612,9 @@ def acquire_token_by_device_flow(self, flow, **kwargs):
# 2018-10-4 Hack:
# during transition period,
# service seemingly need both device_code and code parameter.
headers={'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID)},
**kwargs)

def acquire_token_by_username_password(
Expand All @@ -605,13 +635,18 @@ def acquire_token_by_username_password(
- an error response would contain "error" and usually "error_description".
"""
scopes = decorate_scope(scopes, self.client_id)
headers = {'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)}
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(username)
user_realm_result = self.authority.user_realm_discovery(username, headers['client-request-id'])
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes, **kwargs)
user_realm_result, username, password, scopes=scopes, headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
username, password, scope=scopes, **kwargs)
username, password, scope=scopes,
headers=headers,
**kwargs)

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -668,6 +703,9 @@ def acquire_token_for_client(self, scopes, **kwargs):
# TBD: force_refresh behavior
return self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)},
**kwargs)

def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
Expand Down Expand Up @@ -703,5 +741,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
# so that the calling app could use id_token_claims to implement
# their own cache mapping, which is likely needed in web apps.
data=dict(kwargs.pop("data", {}), requested_token_use="on_behalf_of"),
headers={'client-request-id': _get_new_correlation_id(),
'x-client-current-telemetry': _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID)},
**kwargs)

5 changes: 3 additions & 2 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ def __init__(self, authority_url, validate_authority=True,
_, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID
self.is_adfs = self.tenant.lower() == 'adfs'

def user_realm_discovery(self, username, response=None):
def user_realm_discovery(self, username, correlation_id, response=None):
# It will typically return a dict containing "ver", "account_type",
# "federation_protocol", "cloud_audience_urn",
# "federation_metadata_url", "federation_active_auth_url", etc.
if self.instance not in self.__class__._domains_without_user_realm_discovery:
resp = response or requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json'},
headers={'Accept':'application/json',
'client-request-id': correlation_id},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
if resp.status_code != 404:
resp.raise_for_status()
Expand Down
7 changes: 3 additions & 4 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}


def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
# type: (list, **dict) -> dict
# The naming of this method is following the wording of this specs
Expand All @@ -214,9 +213,9 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
if not self.configuration.get(DAE):
raise ValueError("You need to provide device authorization endpoint")
flow = self.session.post(self.configuration[DAE],
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs).json()
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs).json()
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
flow["expires_in"] = int(flow.get("expires_in", 1800))
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this
Expand Down
5 changes: 3 additions & 2 deletions tests/test_authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ def test_memorize(self):
# We now pretend this authority supports no User Realm Discovery
class MockResponse(object):
status_code = 404
a.user_realm_discovery("[email protected]", response=MockResponse())
a.user_realm_discovery("[email protected]", "ecbecaf4-1759-498d-ae24-25a98a8eca27", response=MockResponse())
self.assertIn(
"login.microsoftonline.com",
Authority._domains_without_user_realm_discovery,
"user_realm_discovery() should memorize domains not supporting URD")
a.user_realm_discovery("[email protected]",
response="This would cause exception if memorization did not work")
"ecbecaf4-1759-498d-ae24-25a98a8eca27",
response="This would cause exception if memorization did not work")