Patch summary (text before the first "diff --git" header is commentary, not
part of any hunk):

  * Authority.__init__ gains an optional ``openid_config`` parameter. When it
    is None the endpoint discovery runs as before; otherwise the supplied
    configuration is used and no discovery request is made.
  * The inline discovery logic of __init__ is split into module-level helpers:
    requires_instance_discovery, get_tenant_discovery_endpoint,
    get_instance_discovery_base_url, default_tenant_discovery_endpoint and
    validate_instance_discovery_payload.
  * user_realm_discovery now builds its request through
    get_userrealm_discovery_request_info and interprets the response through
    verify_user_realm_discovery_response.
  * instance_discovery now takes its URL and query parameters from
    get_instance_discovery_request_info.

Review notes:

  * user_realm_discovery previously evaluated ``response or requests.get(...)``
    and now tests ``response is None``. A caller-supplied response object that
    evaluates false is therefore used as-is instead of being replaced by a
    fresh request.
  * instance_discovery merges the generated "url" and "params" entries into
    ``kwargs`` with dict.update, so same-named entries passed by a caller are
    overwritten.
  * Indentation and blank lines below were reconstructed from the hunk line
    counts and Python syntax; confirm them against the target file before
    applying.

diff --git a/msal/authority.py b/msal/authority.py
index d8221eca..38391e79 100644
--- a/msal/authority.py
+++ b/msal/authority.py
@@ -34,7 +34,7 @@ class Authority(object):
     _domains_without_user_realm_discovery = set([])
 
     def __init__(self, authority_url, validate_authority=True,
-            verify=True, proxies=None, timeout=None,
+            verify=True, proxies=None, timeout=None,openid_config=None,
             ):
         """Creates an authority instance, and also validates it.
 
@@ -48,34 +48,19 @@ def __init__(self, authority_url, validate_authority=True,
         self.proxies = proxies
         self.timeout = timeout
         authority, self.instance, tenant = canonicalize(authority_url)
-        parts = authority.path.split('/')
-        is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or (
-            len(parts) == 3 and parts[2].lower().startswith("b2c_"))
-        if (tenant != "adfs" and (not is_b2c) and validate_authority
-                and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
-            payload = instance_discovery(
-                "https://{}{}/oauth2/v2.0/authorize".format(
-                    self.instance, authority.path),
-                verify=verify, proxies=proxies, timeout=timeout)
-            if payload.get("error") == "invalid_instance":
-                raise ValueError(
-                    "invalid_instance: "
-                    "The authority you provided, %s, is not whitelisted. "
-                    "If it is indeed your legit customized domain name, "
-                    "you can turn off this check by passing in "
-                    "validate_authority=False"
-                    % authority_url)
-            tenant_discovery_endpoint = payload['tenant_discovery_endpoint']
-        else:
-            tenant_discovery_endpoint = (
-                'https://{}{}{}/.well-known/openid-configuration'.format(
-                    self.instance,
-                    authority.path, # In B2C scenario, it is "/tenant/policy"
-                    "" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
-                    ))
-        openid_config = tenant_discovery(
-            tenant_discovery_endpoint,
-            verify=verify, proxies=proxies, timeout=timeout)
+
+        if openid_config is None:
+            if requires_instance_discovery(instance=self.instance, authority=authority, tenant=tenant, validate_authority=validate_authority):
+                tenant_discovery_endpoint = get_tenant_discovery_endpoint(instance=self.instance, authority=authority,
+                                                                          authority_url=authority_url,
+                                                                          timeout=self.timeout, verify=self.verify, proxies=self.proxies)
+            else:
+                tenant_discovery_endpoint = default_tenant_discovery_endpoint(instance=self.instance,
+                                                                              authority_path=authority.path,
+                                                                              tenant=tenant)
+
+            openid_config = tenant_discovery(tenant_discovery_endpoint, timeout=self.timeout, verify=self.verify, proxies=self.proxies)
+
         logger.debug("openid_config = %s", openid_config)
         self.authorization_endpoint = openid_config['authorization_endpoint']
         self.token_endpoint = openid_config['token_endpoint']
@@ -87,19 +72,85 @@ def user_realm_discovery(self, username, correlation_id=None, response=None):
         # "federation_protocol", "cloud_audience_urn",
         # "federation_metadata_url", "federation_active_auth_url", etc.
         if self.instance not in self.__class__._domains_without_user_realm_discovery:
-            resp = response or requests.get(
-                "https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
-                    netloc=self.instance, username=username),
-                headers={'Accept':'application/json',
-                         'client-request-id': correlation_id},
-                verify=self.verify, proxies=self.proxies, timeout=self.timeout)
-            if resp.status_code != 404:
-                resp.raise_for_status()
-                return resp.json()
+            if response is None:
+                response = requests.get(verify=self.verify, proxies=self.proxies, timeout=self.timeout,
+                                        **get_userrealm_discovery_request_info(instance=self.instance,
+                                                                               username=username,
+                                                                               correlation_id=correlation_id))
+
+            discovery_payload = verify_user_realm_discovery_response(response)
+
+            if discovery_payload is not None:
+                return discovery_payload
+
             self.__class__._domains_without_user_realm_discovery.add(self.instance)
         return {} # This can guide the caller to fall back normal ROPC flow
 
 
+def validate_instance_discovery_payload(payload, authority_url):
+    if payload.get("error") == "invalid_instance":
+        raise ValueError(
+            "invalid_instance: "
+            "The authority you provided, %s, is not whitelisted. "
+            "If it is indeed your legit customized domain name, "
+            "you can turn off this check by passing in "
+            "validate_authority=False"
+            % authority_url)
+
+
+def verify_user_realm_discovery_response(user_realm_discovery_response):
+    if user_realm_discovery_response.status_code != 404:
+        user_realm_discovery_response.raise_for_status()
+        return user_realm_discovery_response.json()
+
+    return None
+
+
+def get_userrealm_discovery_request_info(instance, username, correlation_id):
+    return {"url": "https://{netloc}/common/userrealm/{username}?api-version=1.0".format(netloc=instance,
+                                                                                         username=username),
+            "headers": {'Accept': 'application/json', 'client-request-id': correlation_id}}
+
+
+def get_instance_discovery_request_info(url):
+    return {
+        "url": 'https://{}/common/discovery/instance'.format( # Note: This URL seemingly returns V1 endpoint only
+            WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
+            # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
+            # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
+        ),
+        "params": {'authorization_endpoint': url, 'api-version': '1.0'}
+    }
+
+
+def requires_instance_discovery(instance, authority, tenant, validate_authority=True):
+    parts = authority.path.split('/')
+    is_b2c = any(instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or (
+        len(parts) == 3 and parts[2].lower().startswith("b2c_"))
+    return (tenant != "adfs" and (not is_b2c) and validate_authority
+            and instance not in WELL_KNOWN_AUTHORITY_HOSTS)
+
+
+def get_tenant_discovery_endpoint(instance, authority, authority_url, **kwargs):
+    payload = instance_discovery(
+        url=get_instance_discovery_base_url(instance=instance, authority_path=authority.path),
+        **kwargs
+    )
+    validate_instance_discovery_payload(payload, authority_url)
+    return payload['tenant_discovery_endpoint']
+
+
+def get_instance_discovery_base_url(instance, authority_path):
+    return "https://{}{}/oauth2/v2.0/authorize".format(instance, authority_path)
+
+
+def default_tenant_discovery_endpoint(instance, authority_path, tenant):
+    return 'https://{}{}{}/.well-known/openid-configuration'.format(
+        instance,
+        authority_path, # In B2C scenario, it is "/tenant/policy"
+        "" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
+    )
+
 def canonicalize(authority_url):
     # Returns (url_parsed_result, hostname_in_lowercase, tenant)
     authority = urlparse(authority_url)
@@ -114,14 +165,8 @@ def canonicalize(authority_url):
     return authority, authority.hostname, parts[1]
 
 def instance_discovery(url, **kwargs):
-    return requests.get( # Note: This URL seemingly returns V1 endpoint only
-        'https://{}/common/discovery/instance'.format(
-            WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
-            # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
-            # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
-            ),
-        params={'authorization_endpoint': url, 'api-version': '1.0'},
-        **kwargs).json()
+    kwargs.update(get_instance_discovery_request_info(url))
+    return requests.get(**kwargs).json()
 
 def tenant_discovery(tenant_discovery_endpoint, **kwargs):
     # Returns Openid Configuration