From a245cee6ab56d6befe7d45e4ac907bfb4ca357bc Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 30 Sep 2024 14:37:09 -0700 Subject: [PATCH 01/12] Add default impl to handle token challenges --- sdk/core/azure-core/CHANGELOG.md | 4 +- .../core/pipeline/policies/_authentication.py | 37 ++++++-- .../policies/_authentication_async.py | 37 ++++++-- .../azure/core/pipeline/policies/_utils.py | 89 +++++++++++++++++++ sdk/core/azure-core/tests/test_utils.py | 13 ++- 5 files changed, 160 insertions(+), 20 deletions(-) diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 2abf318002b8..470f32b7b067 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,9 +1,11 @@ # Release History -## 1.31.1 (Unreleased) +## 1.32.0 (Unreleased) ### Features Added +- Added a default implementation to handle token challenges in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy`. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index dc3e23de37c8..4af989c271d3 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time +import base64 from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast from azure.core.credentials import ( TokenCredential, @@ -19,6 +20,7 @@ from azure.core.rest import HttpResponse, HttpRequest from . import HTTPPolicy, SansIOHTTPPolicy from ...exceptions import ServiceRequestError +from ._utils import get_challenge_parameter if TYPE_CHECKING: @@ -82,13 +84,7 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_token(self, *scopes: str, **kwargs: Any) -> None: - """Request a new token from the credential. - - This will call the credential's appropriate method to get a token and store it in the policy. - - :param str scopes: The type of access needed. - """ + def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: if self._enable_cae: kwargs.setdefault("enable_cae", self._enable_cae) @@ -99,9 +95,19 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> None: if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) else: - self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + return token + + def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = self._get_token(*scopes, **kwargs) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): @@ -191,6 +197,19 @@ def on_challenge( :rtype: bool """ # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") + if claims: + try: + token = self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + self._update_headers(request.http_request.headers, bearer_token) + return True + except Exception as ex: # pylint:disable=broad-except + return False return False def on_response( diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 7fb68a606a39..e3456bfa0b2a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time +import base64 from typing import Any, Awaitable, Optional, cast, TypeVar, Union from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions @@ -23,6 +24,7 @@ ) from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.utils._utils import get_running_async_lock +from ._utils import get_challenge_parameter from .._tools_async import await_result @@ -138,6 +140,19 @@ async def on_challenge( :rtype: bool """ # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") + if claims: + try: + token = await self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + self._update_headers(request.http_request.headers, bearer_token) + return True + except Exception as ex: # pylint:disable=broad-except + return False return False def on_response( @@ -169,13 +184,7 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_token(self, *scopes: str, **kwargs: Any) -> None: - """Request a new token from the credential. - - This will call the credential's appropriate method to get a token and store it in the policy. - - :param str scopes: The type of access needed. - """ + async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: if self._enable_cae: kwargs.setdefault("enable_cae", self._enable_cae) @@ -186,14 +195,24 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> None: if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - self._token = await await_result( + token = await await_result( cast(AsyncSupportsTokenInfo, self._credential).get_token_info, *scopes, options=options, ) else: - self._token = await await_result( + token = await await_result( cast(AsyncTokenCredential, self._credential).get_token, *scopes, **kwargs, ) + return token + + async def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = await self._get_token(*scopes, **kwargs) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 1733632a9ab2..8c03ad5e399c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -102,3 +102,92 @@ def get_domain(url: str) -> str: :return: The domain of the url. """ return str(urlparse(url).netloc).lower() + + +def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]: + """ + Parses the specified parameter from a challenge header found in the response. + + :param headers: The response headers to parse. + :param challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". + :param challenge_parameter: The parameter key name to search for. + :return: The value of the parameter name if found. + """ + header_value = headers.get("WWW-Authenticate") + if not header_value: + return None + + scheme = challenge_scheme + parameter = challenge_parameter + header_span = header_value + + # Iterate through each challenge value. + while get_next_challenge(header_span): + challenge_key, header_span = get_next_challenge(header_span) + # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. + while get_next_parameter(header_span): + key, value, header_span = get_next_parameter(header_span) + if challenge_key.lower() == scheme.lower() and key.lower() == parameter.lower(): + return value + + return None + + +def get_next_challenge(header_value: str) -> Optional[tuple[str, str]]: + """ + Iterates through the challenge schemes present in a challenge header. + + :param header_value: The header value which will be sliced to remove the first parsed challenge key. + :return: The parsed challenge scheme and the remaining header value. + """ + header_value = header_value.lstrip(' ') + end_of_challenge_key = header_value.find(' ') + + if end_of_challenge_key < 0: + return None + + challenge_key = header_value[:end_of_challenge_key] + header_value = header_value[end_of_challenge_key + 1:] + + return challenge_key, header_value + + +def get_next_parameter(header_value: str, separator: str = '=') -> Optional[tuple[str, str, str]]: + """ + Iterates through a challenge header value to extract key-value parameters. + + :param header_value: The header value after being parsed by get_next_challenge. + :param separator: The challenge parameter key-value pair separator, default is '='. + :return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value). + """ + space_or_comma = " ," + header_value = header_value.lstrip(space_or_comma) + + next_space = header_value.find(' ') + next_separator = header_value.find(separator) + + if next_space < next_separator and next_space != -1: + return None + + if next_separator < 0: + return None + + param_key = header_value[:next_separator].strip() + header_value = header_value[next_separator + 1:] + + quote_index = header_value.find('"') + + if quote_index >= 0: + header_value = header_value[quote_index + 1:] + param_value = header_value[:header_value.find('"')] + else: + trailing_delimiter_index = header_value.find(' ') + if trailing_delimiter_index >= 0: + param_value = header_value[:trailing_delimiter_index] + else: + param_value = header_value + + if header_value != param_value: + header_value = header_value[len(param_value) + 1:] + + return param_key, param_value, header_value diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index c09b48c9c5c5..1bd2abee140a 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -8,7 +8,7 @@ import pytest from azure.core.utils import case_insensitive_dict from azure.core.utils._utils import get_running_async_lock -from azure.core.pipeline.policies._utils import parse_retry_after +from azure.core.pipeline.policies._utils import parse_retry_after, get_challenge_parameter @pytest.fixture() @@ -146,3 +146,14 @@ def test_parse_retry_after(): assert ret == 0 ret = parse_retry_after("0.9") assert ret == 0.9 + +def test_get_challenge_parameter(): + headers = {"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'} + assert get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net" + assert get_challenge_parameter(headers, "Bearer", "foo") is None + + headers = {"WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'} + assert get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/common/oauth2/authorize" + assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims" + assert get_challenge_parameter(headers, "Bearer", "claims") == 'eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==' \ No newline at end of file From c53a91b3163a4418603687d55333de55e7b53d88 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 30 Sep 2024 14:59:54 -0700 Subject: [PATCH 02/12] update version --- sdk/core/azure-core/azure/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index 10fcd28a3fcf..1c43dbb9b140 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.31.1" +VERSION = "1.32.0" From fffee09018ce9e688cf81aac20985139411e4548 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 30 Sep 2024 15:29:28 -0700 Subject: [PATCH 03/12] update --- sdk/core/azure-core/azure/core/pipeline/policies/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 8c03ad5e399c..68cea792d4d7 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- import datetime import email.utils -from typing import Optional, cast, Union +from typing import Optional, cast, Union, Tuple from urllib.parse import urlparse from azure.core.pipeline.transport import ( @@ -133,7 +133,7 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: return None -def get_next_challenge(header_value: str) -> Optional[tuple[str, str]]: +def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: """ Iterates through the challenge schemes present in a challenge header. @@ -152,7 +152,7 @@ def get_next_challenge(header_value: str) -> Optional[tuple[str, str]]: return challenge_key, header_value -def get_next_parameter(header_value: str, separator: str = '=') -> Optional[tuple[str, str, str]]: +def get_next_parameter(header_value: str, separator: str = '=') -> Optional[Tuple[str, str, str]]: """ Iterates through a challenge header value to extract key-value parameters. From fa6542779ef080d9cf5844a2596be584d06b5f88 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 30 Sep 2024 16:21:52 -0700 Subject: [PATCH 04/12] update --- .../azure/core/pipeline/policies/_authentication.py | 2 ++ .../core/pipeline/policies/_authentication_async.py | 4 +++- .../azure-core/azure/core/pipeline/policies/_utils.py | 10 ++++++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 4af989c271d3..509df6a5fff8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -201,6 +201,8 @@ def on_challenge( error = get_challenge_parameter(headers, "Bearer", "error") if error == "insufficient_claims": encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") if claims: try: diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index e3456bfa0b2a..ac51db7ffcdf 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -144,12 +144,14 @@ async def on_challenge( error = get_challenge_parameter(headers, "Bearer", "error") if error == "insufficient_claims": encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") if claims: try: token = await self._get_token(*self._scopes, claims=claims) bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token - self._update_headers(request.http_request.headers, bearer_token) + request.http_request.headers["Authorization"] = "Bearer " + bearer_token return True except Exception as ex: # pylint:disable=broad-except return False diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 68cea792d4d7..1cd5aaccc513 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -123,10 +123,16 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: # Iterate through each challenge value. while get_next_challenge(header_span): - challenge_key, header_span = get_next_challenge(header_span) + challenge = get_next_challenge(header_span) + if not challenge: + break + challenge_key, header_span = challenge # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. while get_next_parameter(header_span): - key, value, header_span = get_next_parameter(header_span) + parameters = get_next_parameter(header_span) + if not parameters: + break + key, value, header_span = parameters if challenge_key.lower() == scheme.lower() and key.lower() == parameter.lower(): return value From 6b125599cc64b6729a8840aded8340a9122848b3 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 30 Sep 2024 16:59:57 -0700 Subject: [PATCH 05/12] update --- .../core/pipeline/policies/_authentication.py | 7 +++---- .../pipeline/policies/_authentication_async.py | 7 +++---- .../azure/core/pipeline/policies/_utils.py | 15 +++++++++------ 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 509df6a5fff8..4fc2623f2567 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -95,10 +95,9 @@ def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "Acces if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) else: - token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) - return token + return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) def _request_token(self, *scopes: str, **kwargs: Any) -> None: """Request a new token from the credential. @@ -210,7 +209,7 @@ def on_challenge( bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token self._update_headers(request.http_request.headers, bearer_token) return True - except Exception as ex: # pylint:disable=broad-except + except Exception: # pylint:disable=broad-except return False return False diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index ac51db7ffcdf..e776430f8651 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -153,7 +153,7 @@ async def on_challenge( bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token request.http_request.headers["Authorization"] = "Bearer " + bearer_token return True - except Exception as ex: # pylint:disable=broad-except + except Exception: # pylint:disable=broad-except return False return False @@ -197,18 +197,17 @@ async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - token = await await_result( + return await await_result( cast(AsyncSupportsTokenInfo, self._credential).get_token_info, *scopes, options=options, ) else: - token = await await_result( + return await await_result( cast(AsyncTokenCredential, self._credential).get_token, *scopes, **kwargs, ) - return token async def _request_token(self, *scopes: str, **kwargs: Any) -> None: """Request a new token from the credential. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 1cd5aaccc513..5cd65d55a3fb 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -108,10 +108,11 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: """ Parses the specified parameter from a challenge header found in the response. - :param headers: The response headers to parse. - :param challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". - :param challenge_parameter: The parameter key name to search for. + :param dict[str, str] headers: The response headers to parse. + :param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". + :param str challenge_parameter: The parameter key name to search for. :return: The value of the parameter name if found. + :rtype: str or None """ header_value = headers.get("WWW-Authenticate") if not header_value: @@ -143,8 +144,9 @@ def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: """ Iterates through the challenge schemes present in a challenge header. - :param header_value: The header value which will be sliced to remove the first parsed challenge key. + :param str header_value: The header value which will be sliced to remove the first parsed challenge key. :return: The parsed challenge scheme and the remaining header value. + :rtype: tuple[str, str] or None """ header_value = header_value.lstrip(' ') end_of_challenge_key = header_value.find(' ') @@ -162,9 +164,10 @@ def get_next_parameter(header_value: str, separator: str = '=') -> Optional[Tupl """ Iterates through a challenge header value to extract key-value parameters. - :param header_value: The header value after being parsed by get_next_challenge. - :param separator: The challenge parameter key-value pair separator, default is '='. + :param str header_value: The header value after being parsed by get_next_challenge. + :param str separator: The challenge parameter key-value pair separator, default is '='. :return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value). + :rtype: tuple[str, str, str] or None """ space_or_comma = " ," header_value = header_value.lstrip(space_or_comma) From bbfe516531eb510fdf042d4bd29168cb09e13b58 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Tue, 1 Oct 2024 08:34:38 -0700 Subject: [PATCH 06/12] update --- .../core/pipeline/policies/_authentication.py | 3 +-- .../policies/_authentication_async.py | 11 ++++----- .../azure/core/pipeline/policies/_utils.py | 20 ++++++++-------- sdk/core/azure-core/tests/test_utils.py | 23 +++++++++++++++---- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 4fc2623f2567..eb8f31341ddc 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -96,8 +96,7 @@ def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "Acces options[key] = kwargs.pop(key) # type: ignore[literal-required] return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) - else: - return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) def _request_token(self, *scopes: str, **kwargs: Any) -> None: """Request a new token from the credential. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index e776430f8651..a73225a17eea 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -202,12 +202,11 @@ async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", *scopes, options=options, ) - else: - return await await_result( - cast(AsyncTokenCredential, self._credential).get_token, - *scopes, - **kwargs, - ) + return await await_result( + cast(AsyncTokenCredential, self._credential).get_token, + *scopes, + **kwargs, + ) async def _request_token(self, *scopes: str, **kwargs: Any) -> None: """Request a new token from the credential. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 5cd65d55a3fb..29f1e2acc816 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -148,19 +148,19 @@ def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: :return: The parsed challenge scheme and the remaining header value. :rtype: tuple[str, str] or None """ - header_value = header_value.lstrip(' ') - end_of_challenge_key = header_value.find(' ') + header_value = header_value.lstrip(" ") + end_of_challenge_key = header_value.find(" ") if end_of_challenge_key < 0: return None challenge_key = header_value[:end_of_challenge_key] - header_value = header_value[end_of_challenge_key + 1:] + header_value = header_value[end_of_challenge_key + 1 :] return challenge_key, header_value -def get_next_parameter(header_value: str, separator: str = '=') -> Optional[Tuple[str, str, str]]: +def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]: """ Iterates through a challenge header value to extract key-value parameters. @@ -172,7 +172,7 @@ def get_next_parameter(header_value: str, separator: str = '=') -> Optional[Tupl space_or_comma = " ," header_value = header_value.lstrip(space_or_comma) - next_space = header_value.find(' ') + next_space = header_value.find(" ") next_separator = header_value.find(separator) if next_space < next_separator and next_space != -1: @@ -182,21 +182,21 @@ def get_next_parameter(header_value: str, separator: str = '=') -> Optional[Tupl return None param_key = header_value[:next_separator].strip() - header_value = header_value[next_separator + 1:] + header_value = header_value[next_separator + 1 :] quote_index = header_value.find('"') if quote_index >= 0: - header_value = header_value[quote_index + 1:] - param_value = header_value[:header_value.find('"')] + header_value = header_value[quote_index + 1 :] + param_value = header_value[: header_value.find('"')] else: - trailing_delimiter_index = header_value.find(' ') + trailing_delimiter_index = header_value.find(" ") if trailing_delimiter_index >= 0: param_value = header_value[:trailing_delimiter_index] else: param_value = header_value if header_value != param_value: - header_value = header_value[len(param_value) + 1:] + header_value = header_value[len(param_value) + 1 :] return param_key, param_value, header_value diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index 1bd2abee140a..1b2de27d4725 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -147,13 +147,26 @@ def test_parse_retry_after(): ret = parse_retry_after("0.9") assert ret == 0.9 + def test_get_challenge_parameter(): - headers = {"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'} - assert get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + headers = { + "WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + ) assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net" assert get_challenge_parameter(headers, "Bearer", "foo") is None - headers = {"WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'} - assert get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/common/oauth2/authorize" + headers = { + "WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") + == "https://login.microsoftonline.com/common/oauth2/authorize" + ) assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims" - assert get_challenge_parameter(headers, "Bearer", "claims") == 'eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==' \ No newline at end of file + assert ( + get_challenge_parameter(headers, "Bearer", "claims") + == "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==" + ) From 70eb72e4434921d299b581395cb8cf7ec7da71de Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 08:08:26 -0700 Subject: [PATCH 07/12] Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck --- sdk/core/azure-core/azure/core/pipeline/policies/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 29f1e2acc816..052f3512caaf 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -123,7 +123,7 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: header_span = header_value # Iterate through each challenge value. - while get_next_challenge(header_span): + while True: challenge = get_next_challenge(header_span) if not challenge: break From 3899b1342733609189ac20f53d539477ccb981d1 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 08:08:33 -0700 Subject: [PATCH 08/12] Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck --- sdk/core/azure-core/azure/core/pipeline/policies/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 052f3512caaf..b80563f6916c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -129,7 +129,7 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: break challenge_key, header_span = challenge # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. - while get_next_parameter(header_span): + while True: parameters = get_next_parameter(header_span) if not parameters: break From 70dd0d1248a46dee825094d0fb57a9d870510ae9 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 14:42:22 -0700 Subject: [PATCH 09/12] update --- .../azure/core/pipeline/policies/_utils.py | 2 ++ sdk/core/azure-core/tests/test_utils.py | 31 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index b80563f6916c..e06773fcf2de 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -128,6 +128,8 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: if not challenge: break challenge_key, header_span = challenge + if challenge_key.lower() != scheme.lower(): + continue # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. while True: parameters = get_next_parameter(header_span) diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index 1b2de27d4725..682f5ffe7d54 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -170,3 +170,34 @@ def test_get_challenge_parameter(): get_challenge_parameter(headers, "Bearer", "claims") == "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==" ) + + +def test_get_challenge_parameter_not_found(): + headers = { + "WWW-Authenticate": 'Pop authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert get_challenge_parameter(headers, "Bearer", "resource") is None + + +def test_get_multi_challenge_parameter(): + headers = { + "WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net" Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + ) + assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net" + assert get_challenge_parameter(headers, "Bearer", "foo") is None + + headers = { + "WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") + == "https://login.microsoftonline.com/common/oauth2/authorize" + ) + assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims" + assert ( + get_challenge_parameter(headers, "Bearer", "claims") + == "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==" + ) From 4dbc84052b9c1a0f1bdce86a7df25ccbdcd7a7eb Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 16:23:41 -0700 Subject: [PATCH 10/12] Update sdk/core/azure-core/tests/test_utils.py Co-authored-by: Paul Van Eck --- sdk/core/azure-core/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index 682f5ffe7d54..015557dbec8e 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -190,7 +190,7 @@ def test_get_multi_challenge_parameter(): assert get_challenge_parameter(headers, "Bearer", "foo") is None headers = { - "WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' + "WWW-Authenticate": 'Digest realm="foo@test.com", qop="auth,auth-int", nonce="123456abcdefg", opaque="123456", Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' } assert ( get_challenge_parameter(headers, "Bearer", "authorization_uri") From dca906e8efcef1344b12bade2225161a8e288f2c Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 16:23:47 -0700 Subject: [PATCH 11/12] Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck --- sdk/core/azure-core/azure/core/pipeline/policies/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index e06773fcf2de..dce2c45bc5a3 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -136,7 +136,7 @@ def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: if not parameters: break key, value, header_span = parameters - if challenge_key.lower() == scheme.lower() and key.lower() == parameter.lower(): + if key.lower() == parameter.lower(): return value return None From ce6e3290774a09ec979ee882dcbec06fb8d894ed Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 3 Oct 2024 16:29:59 -0700 Subject: [PATCH 12/12] update --- .../azure/core/pipeline/policies/_authentication.py | 13 +++++++------ .../core/pipeline/policies/_authentication_async.py | 11 ++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index eb8f31341ddc..537270038eee 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -201,15 +201,16 @@ def on_challenge( encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") if not encoded_claims: return False - claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") - if claims: - try: + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: token = self._get_token(*self._scopes, claims=claims) bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token - self._update_headers(request.http_request.headers, bearer_token) + request.http_request.headers["Authorization"] = "Bearer " + bearer_token return True - except Exception: # pylint:disable=broad-except - return False + except Exception: # pylint:disable=broad-except + return False return False def on_response( diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index a73225a17eea..f97b8df3b7b2 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -146,15 +146,16 @@ async def on_challenge( encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") if not encoded_claims: return False - claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8") - if claims: - try: + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: token = await self._get_token(*self._scopes, claims=claims) bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token request.http_request.headers["Authorization"] = "Bearer " + bearer_token return True - except Exception: # pylint:disable=broad-except - return False + except Exception: # pylint:disable=broad-except + return False return False def on_response(