Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ccd64b9
Add tests
mccoyp Sep 12, 2024
34fcfed
Implement CAE support
mccoyp Sep 12, 2024
d07e0a5
Share implementation across libraries
mccoyp Sep 13, 2024
84351af
Enable CAE; provide claims only in challenges
mccoyp Sep 18, 2024
c22cb08
Update tests for success scenarios
mccoyp Sep 19, 2024
d854071
Handle non-consecutive challenges (in Keys)
mccoyp Sep 19, 2024
6c19bbc
Cover invalid challenge flows
mccoyp Sep 20, 2024
c919056
Handle (in)valid challenge flows
mccoyp Sep 20, 2024
ff731e8
Share updates across libraries
mccoyp Sep 20, 2024
237c57b
Fix spelling, pylint
mccoyp Sep 20, 2024
013673b
Update changelogs
mccoyp Sep 26, 2024
5da13ff
Update tests for feedback
mccoyp Sep 26, 2024
36cb9fd
Use super() instead of private attribute
mccoyp Sep 26, 2024
f9ff176
Add live test; assert scope
mccoyp Sep 26, 2024
bf8f054
Fix auth policy to send scope correctly
mccoyp Sep 26, 2024
850e6e8
Async tests; sync challenge policy code
mccoyp Sep 26, 2024
e78d4e9
Ensure no re-sending claims in tests
mccoyp Oct 3, 2024
1a6c9f7
Fix policy to handle KV -> KV challenge
mccoyp Oct 3, 2024
3fad4db
Share bug fix across libraries
mccoyp Oct 3, 2024
ba2a954
Clarify test variable names
mccoyp Oct 4, 2024
92c6704
Correctly handle token refreshes
mccoyp Oct 5, 2024
8e65726
Bump Core dep for SupportsTokenInfo protocol support
mccoyp Oct 8, 2024
93c7eaa
(Async)SupportsTokenInfo support/tests
mccoyp Oct 8, 2024
bf25378
Pylint
mccoyp Oct 8, 2024
3cb6481
Mention Core bump, enable_cae kwarg in changelogs
mccoyp Oct 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Share updates across libraries
  • Loading branch information
mccoyp committed Oct 7, 2024
commit ff731e8376d11d1822323e331315ad77568b92e8
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,48 @@

from copy import deepcopy
import time
from typing import Any, Optional
from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union
from typing_extensions import ParamSpec
from urllib.parse import urlparse

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.rest import HttpRequest
from azure.core.rest import AsyncHttpResponse, HttpRequest

from . import http_challenge_cache as ChallengeCache
from .challenge_auth_policy import _enforce_tls, _update_challenge


P = ParamSpec("P")
T = TypeVar("T")


@overload
async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ...


@overload
async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...


async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
"""If func returns an awaitable, await it.

:param func: The function to run.
:type func: callable
:param args: The positional arguments to pass to the function.
:type args: list
:rtype: any
:return: The result of the function
"""
result = func(*args, **kwargs)
if isinstance(result, Awaitable):
return await result
return result


class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
"""Policy for handling HTTP authentication challenges.

Expand All @@ -42,6 +72,77 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any
self._token: Optional[AccessToken] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None
self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token

async def send(
self, request: PipelineRequest[HttpRequest]
) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
"""Authorize request with a bearer token and send it to the next policy.

We implement this method to account for the valid scenario where a Key Vault authentication challenge is
immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to
the caller, but we should handle that second challenge as well (and only return any third 401 response).

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
await await_result(self.on_request, request)
response: PipelineResponse[HttpRequest, AsyncHttpResponse]
try:
response = await self.next.send(request)
except Exception: # pylint:disable=broad-except
await await_result(self.on_exception, request)
raise
await await_result(self.on_response, request, response)

if response.http_response.status_code == 401:
return await self.handle_challenge_flow(request, response)
return response

async def handle_challenge_flow(
self,
request: PipelineRequest[HttpRequest],
response: PipelineResponse[HttpRequest, AsyncHttpResponse],
consecutive_challenge: bool = False,
) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
"""Handle the challenge flow of Key Vault and CAE authentication.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:param response: The pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
:param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge.
Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge.
True if the preceding challenge was a Key Vault challenge; False otherwise.

:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = await self.next.send(request)
except Exception: # pylint:disable=broad-except
await await_result(self.on_exception, request)
raise

# If consecutive_challenge == True, this could be a third consecutive 401
if response.http_response.status_code == 401 and not consecutive_challenge:
# If the previous challenge wasn't from CAE, we can try this function one more time
challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
if challenge and not challenge.claims:
return await self.handle_challenge_flow(request, response, consecutive_challenge=True)
await await_result(self.on_response, request, response)
return response


async def on_request(self, request: PipelineRequest) -> None:
_enforce_tls(request)
Expand Down Expand Up @@ -78,6 +179,14 @@ async def on_request(self, request: PipelineRequest) -> None:

async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
try:
# CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary
old_scope: Optional[str] = None
old_tenant: Optional[str] = None
cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
if cached_challenge:
old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default"
old_tenant = cached_challenge.tenant_id

challenge = _update_challenge(request, response)
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
Expand All @@ -87,7 +196,13 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
if self._verify_challenge_resource:
resource_domain = urlparse(scope).netloc
if not resource_domain:
raise ValueError(f"The challenge contains invalid scope '{scope}'.")
# Use the old scope for CAE challenges. The parsing will succeed here since it did before
if challenge.claims and old_scope:
resource_domain = urlparse(old_scope).netloc
challenge._parameters["scope"] = old_scope
challenge.tenant_id = old_tenant
else:
raise ValueError(f"The challenge contains invalid scope '{scope}'.")

request_domain = urlparse(request.http_request.url).netloc
if not request_domain.lower().endswith(f".{resource_domain.lower()}"):
Expand All @@ -104,10 +219,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True)
await self.authorize_request(request, scope, claims=challenge.claims)
else:
await self.authorize_request(
request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id
request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id
)

return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.rest import HttpRequest
from azure.core.rest import HttpRequest, HttpResponse

from .http_challenge import HttpChallenge
from . import http_challenge_cache as ChallengeCache
Expand Down Expand Up @@ -71,6 +71,74 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) ->
self._token: Optional[AccessToken] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None
self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token

def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]:
"""Authorize request with a bearer token and send it to the next policy.

We implement this method to account for the valid scenario where a Key Vault authentication challenge is
immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to
the caller, but we should handle that second challenge as well (and only return any third 401 response).

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest

:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self.on_request(request)
try:
response = self.next.send(request)
except Exception: # pylint:disable=broad-except
self.on_exception(request)
raise

self.on_response(request, response)
if response.http_response.status_code == 401:
return self.handle_challenge_flow(request, response)
return response

def handle_challenge_flow(
self,
request: PipelineRequest[HttpRequest],
response: PipelineResponse[HttpRequest, HttpResponse],
consecutive_challenge: bool = False,
) -> PipelineResponse[HttpRequest, HttpResponse]:
"""Handle the challenge flow of Key Vault and CAE authentication.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:param response: The pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
:param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge.
Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge.
True if the preceding challenge was a Key Vault challenge; False otherwise.

:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = self.on_challenge(request, response)
if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = self.next.send(request)
except Exception: # pylint:disable=broad-except
self.on_exception(request)
raise

# If consecutive_challenge == True, this could be a third consecutive 401
if response.http_response.status_code == 401 and not consecutive_challenge:
# If the previous challenge wasn't from CAE, we can try this function one more time
challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
if challenge and not challenge.claims:
return self.handle_challenge_flow(request, response, consecutive_challenge=True)
self.on_response(request, response)
return response

def on_request(self, request: PipelineRequest) -> None:
_enforce_tls(request)
Expand Down Expand Up @@ -106,6 +174,14 @@ def on_request(self, request: PipelineRequest) -> None:

def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
try:
# CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary
old_scope: Optional[str] = None
old_tenant: Optional[str] = None
cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
if cached_challenge:
old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default"
old_tenant = cached_challenge.tenant_id

challenge = _update_challenge(request, response)
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
Expand All @@ -115,7 +191,13 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
if self._verify_challenge_resource:
resource_domain = urlparse(scope).netloc
if not resource_domain:
raise ValueError(f"The challenge contains invalid scope '{scope}'.")
# Use the old scope for CAE challenges. The parsing will succeed here since it did before
if challenge.claims and old_scope:
resource_domain = urlparse(old_scope).netloc
challenge._parameters["scope"] = old_scope
challenge.tenant_id = old_tenant
else:
raise ValueError(f"The challenge contains invalid scope '{scope}'.")

request_domain = urlparse(request.http_request.url).netloc
if not request_domain.lower().endswith(f".{resource_domain.lower()}"):
Expand All @@ -132,11 +214,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True)
self.authorize_request(request, scope, claims=challenge.claims)
else:
self.authorize_request(
request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id
)
self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id)

return True

Expand Down
Loading