Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import sys
import warnings
import uuid

import requests

Expand Down Expand Up @@ -68,6 +69,16 @@ def extract_certs(public_cert_content):

class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID = "301"
ACQUIRE_TOKEN_ON_BEHALF_OF_ID = "523"
INITIATE_DEVICE_FLOW = "621"
ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622"
ACQUIRE_TOKEN_FOR_CLIENT_ID = "730"
ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832"
GET_ACCOUNTS_ID = "902"
REMOVE_ACCOUNT_ID = "903"

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -275,6 +286,9 @@ def acquire_token_by_authorization_code(
data=dict(
kwargs.pop("data", {}),
scope=decorate_scope(scopes, self.client_id)),
headers={'client-request-id', self.get_new_correlation_id(),
'x-client-current-telemetry', self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)},
**kwargs)

def get_accounts(self, username=None):
Expand Down Expand Up @@ -498,7 +512,7 @@ def _get_app_metadata(self, environment):

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
rt_remover=None, break_condition=lambda response: False, force_refresh=False, **kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
Expand All @@ -511,6 +525,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
on_removing_rt=rt_remover or self.token_cache.remove_rt,
scope=scopes,
headers={'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_SILENT_ID, force_refresh)},

**kwargs)
if "error" not in response:
return response
Expand All @@ -533,6 +551,12 @@ def _validate_ssh_cert_input_data(self, data):
"you must include a string parameter named 'key_id' "
"which identifies the key in the 'req_cnf' argument.")

def get_new_correlation_id(self):
return str(uuid.uuid4())

def _build_current_telemetry_request_header(self, public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


class PublicClientApplication(ClientApplication): # browser app or mobile app

Expand All @@ -554,6 +578,8 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- an error response would contain some other readable key/value pairs.
"""
return self.client.initiate_device_flow(
{'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(self.INITIATE_DEVICE_FLOW)},
scope=decorate_scope(scopes or [], self.client_id),
**kwargs)

Expand All @@ -577,6 +603,9 @@ def acquire_token_by_device_flow(self, flow, **kwargs):
# 2018-10-4 Hack:
# during transition period,
# service seemingly need both device_code and code parameter.
headers={'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID)},
**kwargs)

def acquire_token_by_username_password(
Expand All @@ -597,13 +626,18 @@ def acquire_token_by_username_password(
- an error response would contain "error" and usually "error_description".
"""
scopes = decorate_scope(scopes, self.client_id)
headers = {'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)}
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(username)
user_realm_result = self.authority.user_realm_discovery(username, headers)
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes, **kwargs)
user_realm_result, username, password, scopes=scopes, headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
username, password, scope=scopes, **kwargs)
username, password, scope=scopes,
headers=headers,
**kwargs)

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -660,6 +694,9 @@ def acquire_token_for_client(self, scopes, **kwargs):
# TBD: force_refresh behavior
return self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)},
**kwargs)

def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
Expand Down Expand Up @@ -695,5 +732,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
# so that the calling app could use id_token_claims to implement
# their own cache mapping, which is likely needed in web apps.
data=dict(kwargs.pop("data", {}), requested_token_use="on_behalf_of"),
headers={'client-request-id': self.get_new_correlation_id(),
'x-client-current-telemetry': self._build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID)},
**kwargs)

4 changes: 2 additions & 2 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def __init__(self, authority_url, validate_authority=True,
_, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID
self.is_adfs = self.tenant.lower() == 'adfs'

def user_realm_discovery(self, username):
def user_realm_discovery(self, username, headers):
resp = requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json'},
headers=headers.update({'Accept': 'application/json'}),
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
resp.raise_for_status()
return resp.json()
Expand Down
12 changes: 6 additions & 6 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,8 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}


def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
# type: (list, **dict) -> dict
def initiate_device_flow(self, headers, scope=None, timeout=None, **kwargs):
# type: (dict, list, **dict) -> dict
# The naming of this method is following the wording of this specs
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
"""Initiate a device flow.
Expand All @@ -214,9 +213,10 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
if not self.configuration.get(DAE):
raise ValueError("You need to provide device authorization endpoint")
flow = self.session.post(self.configuration[DAE],
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs).json()
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
headers=headers,
timeout=timeout or self.timeout,
**kwargs).json()
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
flow["expires_in"] = int(flow.get("expires_in", 1800))
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this
Expand Down